"""Cross-check adapter embedding rows against base GGUF rows.

The pad-fallback path sets
`modules_to_save=["embed_tokens", "lm_head"]` so the LoRA adapter
carries its own trained embedding / lm-head rows. At export time the
base GGUF's corresponding rows must match byte-for-byte, or the
adapter's added-token embeddings end up applied on top of
uninitialized base rows — the "gibberish on `<|pad|>`" failure mode.

Contract:

- Runs only when `adapter_config.json::modules_to_save` includes
  either `embed_tokens` or `lm_head`. Otherwise nothing to compare.
- Hashes the added-special-token rows from both sides (PEFT
  safetensors, base GGUF) and compares the digests.
- Skips cleanly when the adapter tokenizer has no added specials
  (nothing changed; base rows are authoritative).
- Raises a `PreflightError` with `probe="embedding_row_sha"` when the
  base GGUF's embedding is block-quantized (re-quantize-after-the-fact
  pipeline; a row-level check is impossible), so the operator knows
  the check could not run rather than it silently passing.

The per-architecture tensor-name map is small — llama-family /
chatml / phi3 / mistral all use the same convention
(`token_embd.weight` / `output.weight`), so one entry covers our v1
registry. Future archs can extend `_DEFAULT_TENSOR_MAP` without
touching the assertion logic.
"""

from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

from dlm.export.errors import PreflightError
from dlm.export.gguf_tensors import _SCALAR_BYTES, load_tensor_index
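# NOTE: `_SCALAR_BYTES` is assumed (inferred from its usage below, not from its
# definition) to map scalar ggml dtypes such as F16/F32 to their per-element
# byte width; dtypes absent from it are treated as block-quantized.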


@dataclass(frozen=True)
class _TensorMap:
    """Maps a dlm logical module name to GGUF + PEFT name fragments."""

    gguf_name: str
    safetensors_suffix: str


# dlm-internal logical name → (GGUF tensor name, PEFT safetensors key suffix).
# `modules_to_save` layers are written under
# `base_model.model.<path>.modules_to_save.default.weight` by PEFT.
# The <path> prefix varies per arch but the trailing suffix is stable;
# we scan the file's key index for the suffix rather than hard-coding
# the full path.
_DEFAULT_TENSOR_MAP: Final[dict[str, _TensorMap]] = {
    "embed_tokens": _TensorMap(
        gguf_name="token_embd.weight",
        safetensors_suffix="embed_tokens.modules_to_save.default.weight",
    ),
    "lm_head": _TensorMap(
        gguf_name="output.weight",
        safetensors_suffix="lm_head.modules_to_save.default.weight",
    ),
}

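# Typical call site, as an illustrative sketch only (the caller module and the
# paths here are assumptions, not part of this file):
#
#     assert_embedding_rows_match(
#         adapter_dir=Path("runs/my-adapter"),
#         base_gguf=Path("exports/base-f16.gguf"),
#     )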
def assert_embedding_rows_match(
    adapter_dir: Path,
    base_gguf: Path,
) -> None:
    """Verify added-token rows agree between adapter safetensors and base GGUF.

    Skip conditions (no-raise):

    - `adapter_config.json` missing or unreadable — export preflight's
      `check_adapter_config` already owns that error path.
    - `modules_to_save` absent / doesn't include embed_tokens or
      lm_head — adapter doesn't own any embedding rows.
    - Tokenizer config has no added special tokens — nothing to
      check; base rows are authoritative.

    Raises `PreflightError(probe="embedding_row_sha")` on a real
    mismatch or on a file-level corruption (missing safetensors,
    absent GGUF tensor, dtype unsupported).
    """
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.is_file():
        return
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        # Export preflight owns the "unreadable adapter config"
        # error path; we silently opt out here rather than double-report.
        return

    saved_modules = cfg.get("modules_to_save") or []
    if not isinstance(saved_modules, list):
        return
    applicable = [m for m in saved_modules if m in _DEFAULT_TENSOR_MAP]
    if not applicable:
        return

    added_token_ids = _added_special_token_ids(adapter_dir)
    if not added_token_ids:
        return

    # Load the base GGUF's tensor index once; reused across modules.
    index = load_tensor_index(base_gguf)

    # Load the adapter safetensors once; reused across modules.
    safetensors_data = _load_adapter_safetensors(adapter_dir)

    mismatches: list[str] = []
    for module in applicable:
        tmap = _DEFAULT_TENSOR_MAP[module]
        adapter_tensor = _find_module_tensor(safetensors_data, tmap.safetensors_suffix)
        if adapter_tensor is None:
            # modules_to_save declared but the tensor isn't in the file —
            # surface as a real error; adapter is malformed.
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"adapter declares modules_to_save={module!r} but "
                    f"{tmap.safetensors_suffix} is absent from adapter "
                    f"safetensors in {adapter_dir}"
                ),
            )
        gguf_entry = index.find(tmap.gguf_name)
        if gguf_entry is None:
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"base GGUF {base_gguf.name} is missing tensor "
                    f"{tmap.gguf_name!r}; cannot verify adapter "
                    f"modules_to_save={module!r}"
                ),
            )
        if gguf_entry.dtype not in _SCALAR_BYTES:
            # The base was quantized to a block-quantized type (user
            # ran the flow with a non-default quant that touches the
            # embedding). We can't read rows. Surface as a preflight
            # error so the operator knows the check didn't run rather
            # than silently passing.
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"{tmap.gguf_name!r} in {base_gguf.name} is "
                    f"block-quantized (ggml dtype {gguf_entry.dtype}); "
                    "re-export with embedding tensors left at F16 or "
                    "disable the embedding_rows checker explicitly."
                ),
            )

        adapter_rows = _as_row_list(adapter_tensor)
        for tid in added_token_ids:
            if tid < 0 or tid >= len(adapter_rows):
                # Added token id is out of the adapter-tensor vocab
                # range. Either the tokenizer added tokens without
                # resizing embed_tokens, or the safetensors shape is stale.
                raise PreflightError(
                    probe="embedding_row_sha",
                    detail=(
                        f"added token id {tid} is out of range for "
                        f"{module} (adapter has {len(adapter_rows)} rows)"
                    ),
                )
            adapter_sha = hashlib.sha256(adapter_rows[tid]).hexdigest()
            # `index.row_bytes` raises `PreflightError` on dtype mismatch /
            # out-of-range — those bubble up naturally; no catch-and-rethrow.
            base_row = index.row_bytes(tmap.gguf_name, tid)
            base_sha = hashlib.sha256(base_row).hexdigest()
            if adapter_sha != base_sha:
                mismatches.append(
                    f"{module}[{tid}]: adapter={adapter_sha[:12]}… base={base_sha[:12]}…"
                )

    if mismatches:
        raise PreflightError(
            probe="embedding_row_sha",
            detail=(
                "adapter embedding rows disagree with base GGUF for "
                f"{len(mismatches)} added token(s): {'; '.join(mismatches)}. "
                "The base was likely regenerated against a different tokenizer; "
                "re-run `dlm export` with a fresh base conversion."
            ),
        )


# --- internals ----------------------------------------------------------------


def _added_special_token_ids(adapter_dir: Path) -> list[int]:
    """Return the sorted list of added-special-token ids, or `[]`."""
    cfg_path = adapter_dir / "tokenizer_config.json"
    if not cfg_path.is_file():
        return []
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return []

    added = cfg.get("added_tokens_decoder") or {}
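    # Illustrative shape of what gets scanned (HF tokenizer_config convention;
    # only entries with "special": true count, keyed by the id as a string):
    #   {"32000": {"content": "<|pad|>", "special": true}, ...}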
    if not isinstance(added, dict):
        return []

    ids: list[int] = []
    for key, entry in added.items():
        if not isinstance(entry, dict):
            continue
        if entry.get("special") is not True:
            continue
        try:
            tid = int(key)
        except (TypeError, ValueError):
            continue
        ids.append(tid)
    return sorted(set(ids))


def _load_adapter_safetensors(adapter_dir: Path) -> dict[str, Any]:
    """Materialize the modules_to_save tensors from `adapter_model.safetensors`.

    Returns a dict of `{safetensors key: numpy array}` containing only
    the modules_to_save embedding / lm-head tensors. For simplicity we
    materialize those two tensors eagerly when present — they're the
    biggest tensors in a modules_to_save adapter by far, but still only
    one embedding matrix's worth each, which is bounded by
    `vocab_size * hidden * dtype_bytes` (~100 MB worst-case for our
    launch registry).
    """
    path = adapter_dir / "adapter_model.safetensors"
    if not path.is_file():
        raise PreflightError(
            probe="embedding_row_sha",
            detail=(
                f"adapter_model.safetensors not found in {adapter_dir}; "
                "PEFT writes this on save_pretrained — the adapter may "
                "have been interrupted mid-save"
            ),
        )

    from safetensors import safe_open

    materialized: dict[str, Any] = {}
    try:
        with safe_open(str(path), framework="np") as handle:  # type: ignore[no-untyped-call]
            for key in handle.keys():  # noqa: SIM118 — safetensors API
                # Only materialize the two modules_to_save shapes we
                # care about; LoRA A/B matrices are smaller but numerous.
                if key.endswith(
                    (
                        "embed_tokens.modules_to_save.default.weight",
                        "lm_head.modules_to_save.default.weight",
                    )
                ):
                    materialized[key] = handle.get_tensor(key)
    except OSError as exc:
        raise PreflightError(
            probe="embedding_row_sha",
            detail=f"cannot read adapter safetensors at {path}: {exc}",
        ) from exc
    return materialized


def _find_module_tensor(safetensors_data: dict[str, Any], suffix: str) -> Any:
    """Return the safetensors entry whose key ends with `suffix`, or None.

    PEFT prefixes the key with `base_model.model.` or
    `base_model.model.model.` depending on whether the base model has
    a `model` submodule; matching on the suffix sidesteps that variation.
    """
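    # Example of a key this matches (the prefix is illustrative; it varies
    # with the base model's module tree):
    #   base_model.model.model.embed_tokens.modules_to_save.default.weight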
    for key, tensor in safetensors_data.items():
        if key.endswith(suffix):
            return tensor
    return None


def _as_row_list(tensor: Any) -> list[bytes]:
    """Turn a numpy-like tensor into a per-row `bytes` list.

    The PEFT-saved embedding is (vocab_size, hidden). We slice row i
    down to the raw bytes of that single row's elements; the caller then
    sha256-compares row-by-row without materializing the whole matrix twice.

    We don't convert dtypes here; the comparison is byte-level on both
    sides, and `convert_hf_to_gguf.py` writes the embedding in the
    tensor's native dtype (F16 → F16, F32 → F32). If the adapter was
    saved in BF16 but the base GGUF ended up as F16, the bytes differ
    even for mathematically equal values — that's a real mismatch we
    want to flag (silent dtype conversions break inference).
    """
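    # Illustrative arithmetic (the hidden size is an assumption, not a fact
    # about any specific registry model): an F16 row with hidden=4096 is
    # 4096 * 2 = 8192 bytes, and the sha256 of exactly those bytes is what
    # gets compared against the base GGUF row.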
    import numpy as np

    arr = np.asarray(tensor)
    if arr.ndim < 2:
        # embed_tokens / lm_head are always 2D; a 1D tensor means a
        # shape mismatch. Return empty to cause the caller's bounds
        # check to surface the problem.
        return []
    # `np.ascontiguousarray` guarantees each row is a packed, C-ordered
    # buffer before `tobytes`, so the hashed bytes reflect the stored
    # values regardless of the input tensor's strides.
    rows: list[bytes] = []
    for i in range(arr.shape[0]):
        row = np.ascontiguousarray(arr[i])
        rows.append(row.tobytes())
    return rows