"""Cross-check adapter embedding rows against base GGUF rows.

The pad-fallback path sets
`modules_to_save=["embed_tokens","lm_head"]` so the LoRA adapter
carries its own trained embedding / lm-head rows. At export time the
base GGUF's corresponding rows must match byte-for-byte, or the
adapter's added-token embeddings end up multiplied against
uninitialized base rows — the "gibberish on `<|pad|>`" failure mode.

Contract:

- Runs only when `adapter_config.json::modules_to_save` includes
  either `embed_tokens` or `lm_head` (see the sketch after this
  list). Otherwise there is nothing to compare.
- Hashes the added-special-token rows from both sides (PEFT
  safetensors, base GGUF) and compares them.
- Skips cleanly when the adapter tokenizer has no added specials
  (nothing changed; base rows are authoritative).
- Raises an explanatory `PreflightError(probe="embedding_row_sha")`
  when the base GGUF's embedding is block-quantized
  (re-quantize-after-the-fact pipeline; a row-level byte comparison
  is impossible), so the operator knows the check could not run
  rather than assuming it passed.
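
For illustration (the token id and its value are hypothetical), the
parsed configs of a pad-fallback adapter that activate the full check
look roughly like:

    # adapter_config.json, after json.loads
    {"modules_to_save": ["embed_tokens", "lm_head"]}

    # tokenizer_config.json excerpt, after json.loads
    {"added_tokens_decoder": {
        "32000": {"content": "<|pad|>", "special": True},
    }}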

The per-architecture tensor-name map is small — llama-family /
chatml / phi3 / mistral all use the same convention
(`token_embd.weight` / `output.weight`), so one entry covers our v1
registry. Future archs can extend `_DEFAULT_TENSOR_MAP` without
touching the assertion logic.
"""

from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

from dlm.export.errors import PreflightError
from dlm.export.gguf_tensors import _SCALAR_BYTES, load_tensor_index


@dataclass(frozen=True)
class _TensorMap:
    """Maps a dlm logical module name to GGUF + PEFT name fragments."""

    gguf_name: str
    safetensors_suffix: str


# dlm-internal logical name → (GGUF tensor name, PEFT safetensors key suffix).
# `modules_to_save` layers are written under
# `base_model.model.<path>.modules_to_save.default.weight` by PEFT.
# The <path> prefix varies per arch but the trailing suffix is stable;
# we scan the file's key index for the suffix rather than hard-coding
# the full path.
_DEFAULT_TENSOR_MAP: Final[dict[str, _TensorMap]] = {
    "embed_tokens": _TensorMap(
        gguf_name="token_embd.weight",
        safetensors_suffix="embed_tokens.modules_to_save.default.weight",
    ),
    "lm_head": _TensorMap(
        gguf_name="output.weight",
        safetensors_suffix="lm_head.modules_to_save.default.weight",
    ),
}
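# For illustration only (hypothetical keys): on a llama-family base, PEFT
# typically writes these weights under keys such as
#   base_model.model.model.embed_tokens.modules_to_save.default.weight
#   base_model.model.lm_head.modules_to_save.default.weight
# Both end with the suffixes above, which is what the suffix scan in
# `_find_module_tensor` relies on.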


def assert_embedding_rows_match(
    adapter_dir: Path,
    base_gguf: Path,
) -> None:
    """Verify added-token rows agree between adapter safetensors and base GGUF.

    Skip conditions (no-raise):

    - `adapter_config.json` missing or unreadable — export preflight's
      `check_adapter_config` already owns that error path.
    - `modules_to_save` absent / doesn't include embed_tokens or
      lm_head — the adapter doesn't own any embedding rows.
    - Tokenizer config has no added special tokens — nothing to
      check; base rows are authoritative.

    Raises `PreflightError(probe="embedding_row_sha")` on a real
    mismatch or on file-level corruption (a missing safetensors file,
    an absent GGUF tensor, an unsupported dtype).
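
    A minimal illustrative call from the export preflight (paths are
    hypothetical):

        assert_embedding_rows_match(
            adapter_dir=Path("artifacts/adapter"),
            base_gguf=Path("artifacts/base-f16.gguf"),
        )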
    """
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.is_file():
        return
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        # Export preflight owns the "unreadable adapter config"
        # error path; we silently opt out here rather than double-report.
        return

    saved_modules = cfg.get("modules_to_save") or []
    if not isinstance(saved_modules, list):
        return
    applicable = [m for m in saved_modules if m in _DEFAULT_TENSOR_MAP]
    if not applicable:
        return

    added_token_ids = _added_special_token_ids(adapter_dir)
    if not added_token_ids:
        return

    # Load the base GGUF's tensor index once; reused across modules.
    index = load_tensor_index(base_gguf)

    # Load the adapter safetensors once; reused across modules.
    safetensors_data = _load_adapter_safetensors(adapter_dir)

    mismatches: list[str] = []
    for module in applicable:
        tmap = _DEFAULT_TENSOR_MAP[module]
        adapter_tensor = _find_module_tensor(safetensors_data, tmap.safetensors_suffix)
        if adapter_tensor is None:
            # modules_to_save declared but the tensor isn't in the file —
            # surface as a real error; adapter is malformed.
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"adapter declares modules_to_save={module!r} but "
                    f"{tmap.safetensors_suffix} is absent from adapter "
                    f"safetensors in {adapter_dir}"
                ),
            )
        gguf_entry = index.find(tmap.gguf_name)
        if gguf_entry is None:
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"base GGUF {base_gguf.name} is missing tensor "
                    f"{tmap.gguf_name!r}; cannot verify adapter "
                    f"modules_to_save={module!r}"
                ),
            )
        if gguf_entry.dtype not in _SCALAR_BYTES:
            # The base was quantized to a block-quantized type (user
            # ran the flow with a non-default quant that touches the
            # embedding). We can't read rows. Surface as a preflight
            # error so the operator knows the check didn't run rather
            # than silently passing.
            raise PreflightError(
                probe="embedding_row_sha",
                detail=(
                    f"{tmap.gguf_name!r} in {base_gguf.name} is "
                    f"block-quantized (ggml dtype {gguf_entry.dtype}); "
                    "re-export with embedding tensors left at F16 or "
                    "disable the embedding_rows checker explicitly."
                ),
            )

        adapter_rows = _as_row_list(adapter_tensor)
        for tid in added_token_ids:
            if tid < 0 or tid >= len(adapter_rows):
                # Added token id is out of the adapter-tensor vocab
                # range. Either the tokenizer added tokens without
                # resizing embed_tokens, or the safetensors shape is stale.
                raise PreflightError(
                    probe="embedding_row_sha",
                    detail=(
                        f"added token id {tid} is out of range for "
                        f"{module} (adapter has {len(adapter_rows)} rows)"
                    ),
                )
            adapter_sha = hashlib.sha256(adapter_rows[tid]).hexdigest()
            # `index.row_bytes` raises `PreflightError` on dtype mismatch /
            # out-of-range — those bubble up naturally; no catch-and-rethrow.
            base_row = index.row_bytes(tmap.gguf_name, tid)
            base_sha = hashlib.sha256(base_row).hexdigest()
            if adapter_sha != base_sha:
                mismatches.append(
                    f"{module}[{tid}]: adapter={adapter_sha[:12]}… base={base_sha[:12]}…"
                )

    if mismatches:
        raise PreflightError(
            probe="embedding_row_sha",
            detail=(
                "adapter embedding rows disagree with base GGUF for "
                f"{len(mismatches)} added token(s): {'; '.join(mismatches)}. "
                "The base was regenerated against a different tokenizer; "
                "re-run `dlm export` with a fresh base conversion."
            ),
        )


# --- internals ----------------------------------------------------------------


def _added_special_token_ids(adapter_dir: Path) -> list[int]:
    """Return the sorted list of added-special-token ids, or `[]`."""
    cfg_path = adapter_dir / "tokenizer_config.json"
    if not cfg_path.is_file():
        return []
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return []

    added = cfg.get("added_tokens_decoder") or {}
    if not isinstance(added, dict):
        return []

    ids: list[int] = []
    for key, entry in added.items():
        if not isinstance(entry, dict):
            continue
        if entry.get("special") is not True:
            continue
        try:
            tid = int(key)
        except (TypeError, ValueError):
            continue
        ids.append(tid)
    return sorted(set(ids))


def _load_adapter_safetensors(adapter_dir: Path) -> Any:
    """Load the modules_to_save tensors from PEFT's `adapter_model.safetensors`.

    Returns a dict of `{key: numpy array}` holding only the
    modules_to_save tensors, materialized eagerly when present — they're
    the biggest tensors in a modules_to_save adapter by far, but still
    only one embedding matrix's worth each, which is bounded by
    `vocab_size * hidden * dtype_bytes` (~100 MB worst-case for our
    launch registry).
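
    For scale (a purely hypothetical shape, not a registry entry): a
    32768-vocab, 1536-hidden F16 embedding is 32768 * 1536 * 2 bytes,
    i.e. roughly 96 MiB.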
    """
    path = adapter_dir / "adapter_model.safetensors"
    if not path.is_file():
        raise PreflightError(
            probe="embedding_row_sha",
            detail=(
                f"adapter_model.safetensors not found in {adapter_dir}; "
                "PEFT writes this on save_pretrained — the adapter may "
                "have been interrupted mid-save"
            ),
        )

    from safetensors import safe_open

    materialized: dict[str, Any] = {}
    try:
        with safe_open(str(path), framework="numpy") as handle:  # type: ignore[no-untyped-call]
            for key in handle.keys():  # noqa: SIM118 — safetensors API
                # Only materialize the two modules_to_save shapes we
                # care about; LoRA A/B matrices are smaller but numerous.
                if key.endswith(
                    (
                        "embed_tokens.modules_to_save.default.weight",
                        "lm_head.modules_to_save.default.weight",
                    )
                ):
                    materialized[key] = handle.get_tensor(key)
    except OSError as exc:
        raise PreflightError(
            probe="embedding_row_sha",
            detail=f"cannot read adapter safetensors at {path}: {exc}",
        ) from exc
    return materialized


def _find_module_tensor(safetensors_data: Any, suffix: str) -> Any:
    """Return the safetensors entry whose key ends with `suffix`, or None.

    PEFT prefixes the key with `base_model.model.` or
    `base_model.model.model.` depending on whether the base model has
    a `model` submodule; matching on suffix sidesteps that variation.
    """
    for key, tensor in safetensors_data.items():
        if key.endswith(suffix):
            return tensor
    return None


def _as_row_list(tensor: Any) -> list[bytes]:
    """Turn a numpy-like tensor into a per-row `bytes` list.

    The PEFT-saved embedding is (vocab_size, hidden). We slice row i
    to the bytes of that single row's elements — sha256 then compares
    row-by-row without materializing the whole matrix twice.

    We don't convert dtypes here; the comparison is byte-level on both
    sides, and `convert_hf_to_gguf.py` writes the embedding in the
    tensor's native dtype (F16 → F16, F32 → F32). If the adapter was
    saved in BF16 but the base GGUF ended up as F16, the bytes differ
    even for mathematically equal values — that's a real mismatch we
    want to flag (silent dtype conversions break inference).
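
    A minimal sketch of the row/byte relationship (shape and dtype are
    hypothetical):

        arr = np.zeros((4, 8), dtype=np.float16)
        rows = _as_row_list(arr)
        assert len(rows) == 4
        assert len(rows[0]) == 8 * 2  # eight F16 elements, two bytes each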
    """
    import numpy as np

    arr = np.asarray(tensor)
    if arr.ndim < 2:
        # embed_tokens / lm_head are always 2D; a 1D tensor means a
        # shape mismatch. Return empty so the caller's bounds check
        # surfaces the problem.
        return []
    # Force each row contiguous before `tobytes` so the hash reflects the
    # stored bytes exactly, regardless of the source array's layout.
    rows: list[bytes] = []
    for i in range(arr.shape[0]):
        row = np.ascontiguousarray(arr[i])
        rows.append(row.tobytes())
    return rows