Python · 10481 bytes Raw Blame History
1 """Importance-matrix (imatrix) calibrated quantization.
2
3 Static k-quants lose 1–3 perplexity points on domain-specific vocabulary
4 vs. fp16; an importance matrix built from text the model will actually
5 generate shrinks that gap well under 1 pp (llama.cpp's upstream
6 benchmarks). The replay corpus is the natural calibration source —
7 it's what the adapter was trained on, already on disk, per-document,
8 and SHA-addressable.
9
10 Module contract:
11
12 - `ImatrixArtifact` — dataclass bundling (path, sha256, metadata).
13 - `resolve_imatrix(export_dir, base_revision, corpus_sha256, chunks)`
14 → returns an existing matching artifact or `None`.
15 - `build_imatrix(base_gguf, calibration_text, out_path, ...)` → runs
16 the `llama-imatrix` subprocess and returns an `ImatrixArtifact`.
17 - `calibration_text_from_replay(replay_store, max_chars=...)` →
18 concatenates sampled replay-corpus prose + instruction text into a
19 single calibration string. Deterministic given the same store.
20
21 Cache key: `(base_revision, corpus_sha256, chunks)`. A fresh replay
22 write changes the corpus sha and rebuilds automatically; re-exporting
23 a second quant against the same corpus hits the cache.
24
25 Subprocess: `llama-imatrix -m <base.gguf> -f <calib.txt> -o <out>
26 --chunks N`. Calibration text lands in `<export_dir>/imatrix.calib.txt`
27 and is removed after the subprocess exits (keeps the store lean; the
28 cache sidecar records just the `corpus_sha256`).
29 """
30
31 from __future__ import annotations
32
33 import hashlib
34 import json
35 import logging
36 from collections.abc import Callable, Sequence
37 from dataclasses import dataclass
38 from datetime import UTC, datetime
39 from pathlib import Path
40 from typing import Any, Final
41
42 from dlm.export import vendoring
43 from dlm.export.errors import SubprocessError
44 from dlm.export.quantize import run_checked
45
_LOG = logging.getLogger(__name__)

# Fixed on-disk names inside the export directory: the imatrix binary,
# its JSON cache sidecar, and the transient calibration text file.
_IMATRIX_FILENAME: Final[str] = "imatrix.gguf"
_IMATRIX_META_FILENAME: Final[str] = "imatrix.meta.json"
_CALIB_FILENAME: Final[str] = "imatrix.calib.txt"

# Calibration defaults — `llama-imatrix` crashes on a too-short calib
# file, so we enforce a reasonable floor. These map to ~128k tokens
# (256 chunks × 512 tokens) with a typical BPE tokenizer, which is
# enough to stabilize per-tensor statistics without burning minutes
# on CPU.
DEFAULT_CHUNKS: Final[int] = 256
DEFAULT_CHUNK_SIZE: Final[int] = 512

# Max characters we concatenate from replay before truncating. The
# imatrix binary reads `--chunks × --chunk-size` tokens; a 4x text
# overhead factor covers typical chars-per-token ratios and gives
# headroom.
_CALIB_CHAR_HEADROOM: Final[int] = 4
63
64
@dataclass(frozen=True)
class ImatrixArtifact:
    """An on-disk imatrix binary plus the metadata needed to validate a cache hit."""

    path: Path
    sha256: str
    base_revision: str
    corpus_sha256: str
    chunks: int
    chunk_size: int
    built_at: datetime

    def to_meta_dict(self) -> dict[str, Any]:
        """Serialize to the JSON-sidecar shape.

        Only the bare filename is recorded (not the full path) so the
        sidecar stays valid if the export directory is relocated; the
        timestamp round-trips through `datetime.fromisoformat`.
        """
        meta: dict[str, Any] = {
            "path": self.path.name,
            "sha256": self.sha256,
            "base_revision": self.base_revision,
            "corpus_sha256": self.corpus_sha256,
            "chunks": self.chunks,
            "chunk_size": self.chunk_size,
        }
        meta["built_at"] = self.built_at.isoformat()
        return meta
87
88
def resolve_imatrix(
    export_dir: Path,
    *,
    base_revision: str,
    corpus_sha256: str,
    chunks: int,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> ImatrixArtifact | None:
    """Return the cached artifact if its key matches, else `None`.

    A cache hit requires everything to line up: the binary and sidecar
    exist, the sidecar parses to a dict, (`base_revision`,
    `corpus_sha256`, `chunks`, `chunk_size`) all match, and the binary's
    actual SHA-256 equals the recorded one. Any divergence is treated
    as a miss and the caller rebuilds.
    """
    bin_path = export_dir / _IMATRIX_FILENAME
    meta_path = export_dir / _IMATRIX_META_FILENAME
    if not (bin_path.is_file() and meta_path.is_file()):
        return None

    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return None
    if not isinstance(meta, dict):
        return None

    expected = {
        "base_revision": base_revision,
        "corpus_sha256": corpus_sha256,
        "chunks": chunks,
        "chunk_size": chunk_size,
    }
    if any(meta.get(key) != want for key, want in expected.items()):
        return None

    recorded_sha = meta.get("sha256")
    if not isinstance(recorded_sha, str):
        return None
    # Re-hash the binary itself; a stale sidecar next to a tampered
    # binary must not look like a cache hit.
    actual_sha = _sha256_of_file(bin_path)
    if recorded_sha != actual_sha:
        return None

    try:
        built_at = datetime.fromisoformat(str(meta.get("built_at")))
    except (TypeError, ValueError):
        return None

    return ImatrixArtifact(
        path=bin_path,
        sha256=actual_sha,
        base_revision=base_revision,
        corpus_sha256=corpus_sha256,
        chunks=chunks,
        chunk_size=chunk_size,
        built_at=built_at,
    )
143
144
def build_imatrix(
    *,
    base_gguf: Path,
    calibration_text: str,
    export_dir: Path,
    base_revision: str,
    corpus_sha256: str,
    chunks: int = DEFAULT_CHUNKS,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    bin_override: Path | None = None,
    subprocess_runner: Callable[[Sequence[str]], Any] | None = None,
) -> ImatrixArtifact:
    """Run `llama-imatrix` against `calibration_text`; write the binary + sidecar.

    The `imatrix.meta.json` sidecar is written to a temporary name and
    renamed into place, so a crash mid-write can never pair a valid
    binary with a truncated sidecar. (`imatrix.gguf` itself is produced
    directly by the subprocess; `resolve_imatrix` guards it by sha.)
    On subprocess failure the calibration text file is left behind for
    debugging; on success it's removed.

    Args:
        base_gguf: base GGUF model the importance matrix is computed for.
        calibration_text: prose fed to the tool; must be non-blank.
        export_dir: destination directory; created if missing.
        base_revision / corpus_sha256: recorded verbatim in the sidecar;
            together with `chunks`/`chunk_size` they form the cache key
            that `resolve_imatrix` checks.
        chunks / chunk_size: `--chunks` count and per-chunk token size.
        bin_override: explicit path to the `llama-imatrix` binary.
        subprocess_runner: test seam; defaults to `run_checked`.

    Raises:
        FileNotFoundError: `base_gguf` does not exist.
        ValueError: non-positive `chunks`/`chunk_size`, or blank text.
        SubprocessError: the tool failed, or exited 0 without producing
            the output file.
    """
    if not base_gguf.is_file():
        raise FileNotFoundError(f"imatrix base model missing: {base_gguf}")
    if chunks <= 0:
        raise ValueError(f"chunks must be positive, got {chunks}")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not calibration_text.strip():
        raise ValueError("calibration_text is empty; imatrix needs real text")

    export_dir.mkdir(parents=True, exist_ok=True)
    calib_path = export_dir / _CALIB_FILENAME
    out_path = export_dir / _IMATRIX_FILENAME
    meta_path = export_dir / _IMATRIX_META_FILENAME

    calib_path.write_text(calibration_text, encoding="utf-8")

    run = subprocess_runner if subprocess_runner is not None else run_checked
    argv = build_imatrix_args(
        base_gguf=base_gguf,
        calib_path=calib_path,
        out_path=out_path,
        chunks=chunks,
        bin_override=bin_override,
    )

    _LOG.info("imatrix: building (%d × %d tokens)", chunks, chunk_size)
    # A SubprocessError from the runner propagates as-is; the calib
    # file deliberately stays on disk so operators can rerun by hand.
    run(argv)

    if not out_path.is_file():
        # Defensive: treat a zero exit without output as a failure too.
        raise SubprocessError(
            cmd=[str(a) for a in argv],
            returncode=0,
            stderr_tail=(f"llama-imatrix exited 0 but {out_path.name} was not produced."),
        )

    sha = _sha256_of_file(out_path)
    artifact = ImatrixArtifact(
        path=out_path,
        sha256=sha,
        base_revision=base_revision,
        corpus_sha256=corpus_sha256,
        chunks=chunks,
        chunk_size=chunk_size,
        # Naive second-resolution UTC keeps the sidecar compact;
        # `datetime.fromisoformat` round-trips it in `resolve_imatrix`.
        built_at=datetime.now(UTC).replace(tzinfo=None, microsecond=0),
    )
    # Write-then-rename: the sidecar is never observable half-written.
    tmp_meta = meta_path.with_name(meta_path.name + ".tmp")
    tmp_meta.write_text(
        json.dumps(artifact.to_meta_dict(), indent=2) + "\n",
        encoding="utf-8",
    )
    tmp_meta.replace(meta_path)
    # Calib text is regenerable from the replay corpus; don't leave
    # the concatenated blob on disk.
    calib_path.unlink(missing_ok=True)
    return artifact
220
221
def build_imatrix_args(
    *,
    base_gguf: Path,
    calib_path: Path,
    out_path: Path,
    chunks: int,
    bin_override: Path | None = None,
) -> list[str]:
    """Assemble the `llama-imatrix ...` argv.

    Pure argv construction: no subprocess is launched and there are no
    filesystem side effects beyond the vendoring resolver. Snapshot-
    tested against the pinned upstream CLI shape (audit F09 pattern).
    """
    binary = vendoring.llama_imatrix_bin(bin_override)
    argv: list[str] = [str(binary)]
    argv += ["-m", str(base_gguf)]
    argv += ["-f", str(calib_path)]
    argv += ["-o", str(out_path)]
    argv += ["--chunks", str(chunks)]
    return argv
248
249
def calibration_text_from_replay(
    corpus_path: Path,
    index_path: Path,
    *,
    max_chars: int = DEFAULT_CHUNKS * DEFAULT_CHUNK_SIZE * _CALIB_CHAR_HEADROOM,
) -> tuple[str, str]:
    """Concatenate replay-corpus snapshot content into one calibration blob.

    Returns `(text, corpus_sha256)`. `corpus_sha256` hashes the raw
    `corpus.zst` bytes so the cache key tracks any corpus mutation
    (training a new adapter-version appends to the file → new sha).

    A missing / empty corpus yields `("", "<no-corpus>")` — callers
    decide whether to fall back to static quantization in that case.
    """
    if not corpus_path.is_file():
        return "", "<no-corpus>"

    corpus_sha = _sha256_of_file(corpus_path)

    if not index_path.is_file():
        # A corpus without an index is as useless for calibration as no
        # corpus at all, but the sha still tracks the binary, so a
        # subsequent index write triggers a rebuild.
        return "", corpus_sha

    # Deferred import: the replay package is only needed on this path.
    # Snapshots are plain UTF-8 text the imatrix binary re-tokenizes;
    # no Python tokenizer is involved here.
    from dlm.replay.corpus import iter_snapshots
    from dlm.replay.index import load_index

    pieces: list[str] = []
    budget = max_chars
    for snap in iter_snapshots(corpus_path, load_index(index_path)):
        text = (snap.content or "").strip()
        if not text:
            continue
        if len(text) > budget:
            # Spend whatever budget remains on a truncated slice, then
            # stop walking the corpus.
            if budget > 0:
                pieces.append(text[:budget])
                budget = 0
            break
        pieces.append(text)
        budget -= len(text)

    return "\n\n".join(pieces), corpus_sha
301
302
303 # --- internals ----------------------------------------------------------------
304
305
def _sha256_of_file(path: Path, *, chunk_bytes: int = 1 << 20) -> str:
    """Hex SHA-256 of `path`, streamed in `chunk_bytes`-sized reads.

    Chunked reading keeps memory flat for multi-GB GGUF files; the
    default 1 MiB read size balances syscall count against buffer size.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while block := stream.read(chunk_bytes):
            digest.update(block)
    return digest.hexdigest()