1 """Adapter loader for inference (`dlm prompt`).
2
3 Given a `StorePath` and the current host's `Capabilities`, resolve an
4 `InferencePlan` and load the PEFT model + tokenizer ready for
5 `generate()`. Two paths:
6
7 - **4-bit QLoRA path** (CUDA + bnb installed + adapter was QLoRA-trained):
8 `AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb)`
9 then `PeftModel.from_pretrained(base, adapter_dir)`.
10 - **fp16 / bf16 path** (everything else, including the F05 "CUDA-saved
11 QLoRA resumed on Apple Silicon" case): `AutoModelForCausalLM` at the
12 plan's `precision`, then adapter load. Dequantization for a 4-bit-
13 trained adapter loaded without bnb happens implicitly: the saved
14 LoRA delta weights are already in fp16; loading the BASE at fp16
15 (not 4-bit) is the correct behavior. The adapter adds a small
16 fp16 residual on top of a fp16 base.

The tokenizer is loaded from the **adapter directory**, not
`store.cache/`, because tokenizer bringup persists the final
tokenizer state (including `<|pad|>` additions) into the adapter dir
at training-end. This is the contract export and inference depend on.

Heavy imports are deferred; the orchestration logic that picks args,
paths, and dtypes is unit-testable without HF.
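
Typical call (a sketch; `store`, `spec`, and `caps` come from the CLI
wiring, and the prompt text is illustrative)::

    loaded = load_for_inference(store, spec, caps)
    inputs = loaded.tokenizer("Hello", return_tensors="pt")
    out = loaded.model.generate(**inputs, max_new_tokens=32)
    print(loaded.tokenizer.decode(out[0], skip_special_tokens=True))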
25 """

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from dlm.inference.errors import AdapterNotFoundError
from dlm.inference.plan import InferencePlan

if TYPE_CHECKING:
    from dlm.base_models import BaseModelSpec
    from dlm.store.paths import StorePath


@dataclass(frozen=True)
class LoadedInference:
    """Result of `load_for_inference`."""

    model: Any  # PeftModel — Any to avoid pulling peft into type stubs
    tokenizer: Any
    plan: InferencePlan
    adapter_path: Path


def build_load_kwargs(
    spec: BaseModelSpec,
    plan: InferencePlan,
    *,
    has_bitsandbytes: bool,
) -> dict[str, Any]:
57 """Assemble `AutoModelForCausalLM.from_pretrained` kwargs for `plan`.
58
59 Extracted so unit tests can verify the config-assembly logic
60 without actually loading a model. The real loader calls this plus
61 the HF API; this function returns the dict, nothing more.
62
63 - QLoRA path: `quantization_config=BitsAndBytesConfig(load_in_4bit=True, ...)`.
64 - Dequantize path: plain `torch_dtype=...`; no quantization config.
65 - Plain LoRA / fp: `torch_dtype=...`.
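
    For example, on a bnb-less host with an fp16 plan the result takes
    the plain-dtype shape (a sketch; field values are illustrative)::

        build_load_kwargs(spec, plan, has_bitsandbytes=False)
        # -> {"revision": spec.revision,
        #     "attn_implementation": plan.attn_implementation,
        #     "torch_dtype": torch.float16}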
66 """
    kwargs: dict[str, Any] = {
        "revision": spec.revision,
        "attn_implementation": plan.attn_implementation,
    }

    if not plan.dequantize_on_load and has_bitsandbytes and plan.precision in ("bf16", "fp16"):
        # Only reach here on the real 4-bit CUDA+bnb path.
        from transformers import BitsAndBytesConfig  # pragma: no cover

        compute_dtype = _torch_dtype_for(plan.precision)  # pragma: no cover
        kwargs["quantization_config"] = BitsAndBytesConfig(  # type: ignore[no-untyped-call] # pragma: no cover
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
        )
    else:
        kwargs["torch_dtype"] = _torch_dtype_for(plan.precision)

    return kwargs


def _torch_dtype_for(precision: str) -> Any:
    """Map a precision string to a `torch.dtype`.

    Isolated so unit tests can exercise `build_load_kwargs` without
    torch installed: the fallback returns the precision string, and the
    tests assert the key shape, not the exact dtype object, while the
    real path still gets a `torch.dtype`.
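
    With torch installed, `"bf16"` maps to `torch.bfloat16` and any
    unknown string falls back to `torch.float16`; without torch, the
    input string is returned unchanged.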
95 """
    try:
        import torch
    except ImportError:  # pragma: no cover
        return precision

    lookup = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    return lookup.get(precision, torch.float16)


def resolve_adapter_path(store: StorePath, *, adapter_name: str | None) -> Path:
    """Return the on-disk adapter version dir for inference.

    Single entry point for both the flat (unnamed) and named-adapter
    layouts. Raises `AdapterNotFoundError` with a path-appropriate
    hint when `current.txt` is missing or empty — the most common
    "haven't trained yet" failure mode.
115 """
    if adapter_name is None:
        adapter_path = store.resolve_current_adapter()
        pointer = store.adapter_current_pointer
    else:
        adapter_path = store.resolve_current_adapter_for(adapter_name)
        pointer = store.adapter_current_pointer_for(adapter_name)
    if adapter_path is None or not adapter_path.exists():
        hint = (
            f"no adapter under {pointer}; "
            f"has `dlm train` run successfully"
            f"{f' for adapter {adapter_name!r}' if adapter_name else ''}?"
        )
        raise AdapterNotFoundError(hint)
    return adapter_path


def load_for_inference(  # pragma: no cover
    store: StorePath,
    spec: BaseModelSpec,
    caps: Any,
    *,
    adapter_name: str | None = None,
) -> LoadedInference:
139 """Resolve plan + load base + adapter + tokenizer.
140
141 Pragma'd from unit coverage because it calls `AutoModelForCausalLM.from_pretrained`
142 and `PeftModel.from_pretrained`, which each need ~5 seconds and a
143 real HF cache. Covered by the slow-marked integration test.
144
145 `adapter_name`, when provided, targets the named multi-adapter
146 layout (`adapter/<name>/current.txt`). When `None`, uses the flat
147 single-adapter layout.
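
    A named-adapter call looks like this (adapter name illustrative)::

        loaded = load_for_inference(store, spec, caps, adapter_name="chat")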
148 """
    adapter_path = resolve_adapter_path(store, adapter_name=adapter_name)

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from dlm.inference.plan import resolve_inference

    plan = resolve_inference(adapter_path, caps)
    has_bnb = bool(getattr(caps, "has_bitsandbytes", False))
    kwargs = build_load_kwargs(spec, plan, has_bitsandbytes=has_bnb)

    base = AutoModelForCausalLM.from_pretrained(spec.hf_id, **kwargs)

    from peft import PeftModel

    model = PeftModel.from_pretrained(base, str(adapter_path))
    model.eval()

    # Tokenizer from the adapter dir — source of truth after any
    # vocab growth from training-time bringup.
    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))

    return LoadedInference(
        model=model,
        tokenizer=tokenizer,
        plan=plan,
        adapter_path=adapter_path,
    )