1 """Loader for dlm's ``training_state.pt`` (Sprint 25, gradient_ghost).
2
3 The pre-run ``gradient_ghost`` probe diagnoses adapter convergence
4 without any model load: dlm writes the optimizer state snapshot at
5 end of training, and we can read its ``global_step`` + per-parameter
6 Adam ``exp_avg_sq`` magnitudes to answer "did this train long enough?"
7 in ~50 ms.
8
9 This module is **dlm-aware via file convention only** — we never
10 ``import dlm``. We just ``torch.load`` a path the user gave us. The
11 file format is dlm's contract; we follow it conservatively and fail
12 loudly when the shape drifts.
13
14 ## File-shape reference (verified against dlm 2026-04-20 stores)
15
16 ```
17 training_state.pt:
18 global_step: int
19 epoch: float
20 best_val_loss: float
21 optimizer_state_dict:
22 state: dict[int param_id → {step, exp_avg, exp_avg_sq}]
23 param_groups: list[dict] # lr, betas, eps, params: list[int]
24 scheduler_state_dict: dict
25 scaler_state_dict: dict | None
26 *_rng_state: tensor / tuple / None
27 pinned_versions: dict
28 base_model_revision: str
29 dlm_manifest_hash: str | None
30 use_qlora: bool
31 ```
32
33 Per-param ids are integer indices, NOT names — name attribution lives
34 in ``_param_id_mapping.py``.
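
## Example (illustrative sketch)

A minimal usage sketch; the store path reuses the ``<id>`` placeholder
from ``load_training_state``'s docstring and is not a real id:

```
from pathlib import Path

snap = load_training_state(
    Path("~/.dlm/store/<id>/adapter/versions/v0001").expanduser()
)
print(snap.global_step, snap.best_val_loss, len(snap.per_param))
```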
35 """
36
37 from __future__ import annotations
38
39 import warnings
40 from dataclasses import dataclass
41 from pathlib import Path
42 from typing import Any
43
44 from dlm_sway.core.errors import MissingTrainingStateError, SwayError
45
46
47 class TrainingStateError(SwayError):
48 """Raised when ``training_state.pt`` exists but can't be parsed."""
49
50
51 @dataclass(frozen=True, slots=True)
52 class ParamStat:
53 """Optimizer state for a single trainable parameter (one int id).
54
55 The probe reads ``exp_avg_sq_mean`` as a proxy for "how much
56 gradient variance this parameter has seen recently." Adam's
57 second-moment estimate is a moving average of squared gradients;
58 a high value at end-of-training means the gradients haven't
59 shrunk — typical of an undertrained parameter.
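
    As a rough illustration only (the 10x-median cutoff is an assumption
    made for this sketch, not a dlm or probe contract), a consumer holding
    ``stats: tuple[ParamStat, ...]`` could rank parameters by this proxy::

        import statistics

        med = statistics.median(p.exp_avg_sq_mean for p in stats)
        still_moving = [p for p in stats if p.exp_avg_sq_mean > 10 * med]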
60 """
61
62 param_id: int
63 step: int
64 exp_avg_norm: float
65 exp_avg_sq_mean: float
66 numel: int
67
68
69 @dataclass(frozen=True, slots=True)
70 class TrainingStateSnapshot:
71 """Everything ``gradient_ghost`` needs from a ``training_state.pt``.
72
73 Constructed by :func:`load_training_state`; consumed by
74 ``GradientGhostProbe.run``. Frozen so a probe can pass it around
75 without defensive copies.
76 """
77
78 global_step: int
79 epoch: float
80 best_val_loss: float
81 per_param: tuple[ParamStat, ...]
82 pinned_versions: dict[str, str]
83 base_model_revision: str | None
84 dlm_manifest_hash: str | None
85 use_qlora: bool
86
87
88 def load_training_state(adapter_dir: Path) -> TrainingStateSnapshot:
89 """Load + parse ``adapter_dir/training_state.pt``.
90
91 Parameters
92 ----------
93 adapter_dir:
94 A dlm adapter version directory (e.g.
95 ``~/.dlm/store/<id>/adapter/versions/v0001/``). Must contain
96 a ``training_state.pt`` file.
97
98 Returns
99 -------
100 TrainingStateSnapshot
101 Frozen snapshot of the fields the probe consumes.
102
103 Raises
104 ------
105 MissingTrainingStateError
106 ``training_state.pt`` doesn't exist under ``adapter_dir``. The
107 caller should SKIP (this is a clean signal, not an error).
108 TrainingStateError
109 File exists but can't be parsed (torch import missing,
110 unexpected dict shape, optimizer state missing). The caller
111 should ERROR — something is structurally wrong.
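
    Examples
    --------
    Sketch of the SKIP-vs-ERROR handling described above (how a caller
    reacts is up to the probe; only the exception types are this module's)::

        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError:
            ...  # SKIP: no training_state.pt in this adapter version
        except TrainingStateError:
            ...  # ERROR: file present but structurally wrong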
112 """
113 adapter_dir = Path(adapter_dir)
114 state_path = adapter_dir / "training_state.pt"
115 if not state_path.exists():
116 raise MissingTrainingStateError(adapter_dir)
117
118 # Lazy-import torch so non-gradient-ghost users don't pay the
119 # ~700 ms torch-import cost on `import dlm_sway`.
120 try:
121 import torch
122 except ImportError as exc:
123 raise TrainingStateError(
124 "torch not installed — gradient_ghost reads pytorch-pickled "
125 "training_state.pt files. Install with: pip install 'dlm-sway[hf]'"
126 ) from exc
127
128 # ``weights_only=False`` is required: dlm's training_state.pt
129 # carries pickled RNG state (numpy / python random). Suppressing
130 # the ``FutureWarning`` keeps the probe output clean — this is a
131 # known-trusted artifact dlm produced, not arbitrary user input.
132 with warnings.catch_warnings():
133 warnings.filterwarnings("ignore", category=FutureWarning)
134 try:
135 state = torch.load(str(state_path), map_location="cpu", weights_only=False)
136 except Exception as exc: # noqa: BLE001 — torch.load can raise many shapes
137 raise TrainingStateError(
138 f"failed to torch.load {state_path}: {type(exc).__name__}: {exc}"
139 ) from exc
140
141 if not isinstance(state, dict):
142 raise TrainingStateError(f"{state_path}: expected dict, got {type(state).__name__}")
143 opt = state.get("optimizer_state_dict")
144 if not isinstance(opt, dict):
145 raise TrainingStateError(
146 f"{state_path}: missing 'optimizer_state_dict' (got {type(opt).__name__})"
147 )
148 per_param_state = opt.get("state")
149 if not isinstance(per_param_state, dict):
150 raise TrainingStateError(
151 f"{state_path}: optimizer_state_dict.state is not a dict "
152 f"(got {type(per_param_state).__name__})"
153 )
154
155 per_param: list[ParamStat] = []
156 for pid, ps in per_param_state.items():
157 if not isinstance(pid, int):
158 raise TrainingStateError(
159 f"{state_path}: optimizer_state_dict.state has non-int key "
160 f"{pid!r} (type {type(pid).__name__})"
161 )
162 if not isinstance(ps, dict):
163 continue # Skip malformed entries silently — same as torch.optim does.
164 step_v = _scalar_int(ps.get("step", 0))
165 exp_avg = ps.get("exp_avg")
166 exp_avg_sq = ps.get("exp_avg_sq")
167 per_param.append(
168 ParamStat(
169 param_id=pid,
170 step=step_v,
171 exp_avg_norm=_tensor_norm(exp_avg),
172 exp_avg_sq_mean=_tensor_mean(exp_avg_sq),
173 numel=_tensor_numel(exp_avg_sq),
174 )
175 )
176 per_param.sort(key=lambda s: s.param_id)
177
178 return TrainingStateSnapshot(
179 global_step=int(state.get("global_step", 0) or 0),
180 epoch=float(state.get("epoch", 0.0) or 0.0),
181 best_val_loss=float(state.get("best_val_loss", 0.0) or 0.0),
182 per_param=tuple(per_param),
183 pinned_versions=dict(state.get("pinned_versions") or {}),
184 base_model_revision=state.get("base_model_revision"),
185 dlm_manifest_hash=state.get("dlm_manifest_hash"),
186 use_qlora=bool(state.get("use_qlora", False)),
187 )
188
189
190 def _scalar_int(v: Any) -> int:
191 """torch saves ``step`` as a 0-dim tensor; coerce safely."""
    if v is None:
        return 0
    item = getattr(v, "item", None)
    if callable(item):
        try:
            return int(item())
        except Exception:  # noqa: BLE001
            return 0
    try:
        return int(v)
    except Exception:  # noqa: BLE001
        return 0


def _tensor_norm(t: Any) -> float:
    """L2 norm of ``t`` as a float; 0.0 if ``t`` is None or not tensor-like."""
    if t is None:
        return 0.0
    norm = getattr(t, "norm", None)
    if callable(norm):
        try:
            return float(norm().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_mean(t: Any) -> float:
    """Mean of ``t``'s elements as a float; 0.0 if ``t`` is None or not tensor-like."""
    if t is None:
        return 0.0
    mean = getattr(t, "mean", None)
    if callable(mean):
        try:
            return float(mean().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_numel(t: Any) -> int:
    """Element count of ``t``; 0 if ``t`` is None or not tensor-like."""
    if t is None:
        return 0
    numel = getattr(t, "numel", None)
    if callable(numel):
        try:
            return int(numel())
        except Exception:  # noqa: BLE001
            return 0
    return 0