1 """Merged-GGUF export path — where LoRA deltas fuse into the base.
2
3 CLAUDE.md pitfall #3: `merge_and_unload` on a 4-bit QLoRA base is
4 precision-unsafe. The canonical safety decision now lives in
5 `dlm.export.precision_safety`; this module only hosts the heavy HF
6 merge work plus a tiny pure helper used by property tests.
7
8 The actual HF work (loading the base in fp16, calling `merge_and_unload`,
9 saving to a tmpdir) is pragma'd from unit coverage: it needs a real
10 model and takes minutes on CI. `check_merge_safety` — the pure
11 decision function — is fully covered.
12 """

from __future__ import annotations

from pathlib import Path
from typing import Any

from dlm.export.plan import ExportPlan


def check_merge_safety(plan: ExportPlan, *, was_qlora: bool) -> None:
    """Pure-python truth-table helper — no subprocess, no HF.

    Main export entry points use `dlm.export.precision_safety` so the
    adapter-metadata probe and the merged-export gate stay together.
    This wrapper remains for focused unit/property tests over the
    boolean truth table itself.
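
    A sketch of the truth table under test (hedged: what
    `assert_merge_safe` raises, and how a plan records the dequantize
    confirmation, is defined by `ExportPlan`, not here):

        check_merge_safety(plan, was_qlora=False)  # plain LoRA: always passes
        check_merge_safety(plan, was_qlora=True)   # passes only when the
                                                   # plan confirmed --dequantize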
29 """
30 plan.assert_merge_safe(was_qlora=was_qlora)
31
32
33 def perform_merge( # pragma: no cover
34 adapter_dir: Path,
35 out_hf_dir: Path,
36 *,
37 was_qlora: bool,
38 cached_base_dir: Path,
39 ) -> None:
40 """Load base + adapter, merge_and_unload, save merged HF dir.
41
42 Pragma'd from unit coverage: instantiating a real HF model is
43 >5s per test and requires a cached checkpoint. Exercised by the
44 slow-marked integration test.
45
46 Never entered without `check_merge_safety()` having passed first —
47 the runner enforces the order. `was_qlora=True` additionally
48 requires the plan's `--dequantize` flag to have been confirmed.
49
50 `cached_base_dir` is the HF snapshot dir produced by
51 `base_models.downloader.download_spec(spec).path`; we pass it in
52 (rather than re-`from_pretrained(spec.hf_id, ...)`) so the merge
53 path reuses the already-verified, sha256-pinned cache and never
54 touches the network at export time.
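
    Call-site sketch (hedged: `spec`, `plan`, and the `download_spec`
    import are the runner's wiring as described above, not APIs defined
    in this module):

        cached = download_spec(spec).path         # sha256-pinned snapshot
        check_merge_safety(plan, was_qlora=True)  # gate must pass first
        perform_merge(
            adapter_dir,
            out_hf_dir,
            was_qlora=True,
            cached_base_dir=cached,
        )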
55 """
56 import torch
57 from peft import PeftModel
58 from transformers import AutoModelForCausalLM, AutoTokenizer
59
60 # Both QLoRA and plain-LoRA adapters merge onto the upstream fp16
61 # base weights. For QLoRA, loading in fp16 (rather than re-running
62 # bnb 4-bit quantization) is the dequantization — the base weights
63 # in the cache are already fp16 upstream and LoRA deltas merge at
64 # native precision. `was_qlora` is kept in the signature for
65 # downstream logging / audit trails.
66 _ = was_qlora
67 torch_dtype = torch.float16
68
69 base = AutoModelForCausalLM.from_pretrained(
70 str(cached_base_dir),
71 torch_dtype=torch_dtype,
72 local_files_only=True,
73 )
74 peft: Any = PeftModel.from_pretrained(base, str(adapter_dir))
75 merged = peft.merge_and_unload()
76
77 out_hf_dir.mkdir(parents=True, exist_ok=True)
78 merged.save_pretrained(str(out_hf_dir))
79
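    # Tokenizer is taken from the adapter dir (not the base snapshot) so
    # any training-time tokenizer additions travel with the merged model.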
    tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir), local_files_only=True)
    tokenizer.save_pretrained(str(out_hf_dir))


def perform_vl_merge(  # pragma: no cover
    adapter_dir: Path,
    out_hf_dir: Path,
    *,
    cached_base_dir: Path,
) -> None:
    """VL-aware merge: `AutoModelForImageTextToText` + full processor save.

    Parallel to `perform_merge` but uses the image-text-to-text class so
    the vision tower travels with the merged output (upstream
    `convert_hf_to_gguf.py` drops ViT tensors for Qwen2-VL at our
    pinned tag — the ViT runs through Ollama's preprocessor path — but
    `from_pretrained` still needs the VL class to reconstruct the full
    graph before `merge_and_unload`).

    LoRA adapters for VL bases should target language-model projections
    only (enforced by `preflight.check_vl_target_modules_lm_only`), so
    `merge_and_unload()` touches LM weights exclusively; vision-tower
    weights are saved unmodified.

    `processor.save_pretrained` (not just tokenizer) writes the
    tokenizer + image_processor + processor config together — every
    piece a recipient needs to re-hydrate.
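
    Sketch of the resulting layout (hedged: these are the typical HF
    artifact names for a Qwen2-VL-style checkpoint, not filenames
    verified against our pinned versions):

        out_hf_dir/
            config.json
            model*.safetensors          # merged LM + unmodified vision tower
            tokenizer.json / tokenizer_config.json
            preprocessor_config.json    # the image_processor half
            chat_template.json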
107 """
108 import torch
109 from peft import PeftModel
110 from transformers import AutoModelForImageTextToText, AutoProcessor
111
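    # Reconstruct the full VL graph (LM + vision tower) in fp16 from the
    # pinned local snapshot; as in `perform_merge`, no network at export time.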
    base = AutoModelForImageTextToText.from_pretrained(
        str(cached_base_dir),
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    peft: Any = PeftModel.from_pretrained(base, str(adapter_dir))
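    # Preflight guarantees the adapter targets LM projections only, so the
    # merge below touches LM weights exclusively; the vision tower is saved
    # unmodified.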
    merged = peft.merge_and_unload()

    out_hf_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(out_hf_dir))

    processor = AutoProcessor.from_pretrained(str(cached_base_dir), local_files_only=True)  # type: ignore[no-untyped-call]
    processor.save_pretrained(str(out_hf_dir))