| | 1 | +"""Loader for dlm's ``training_state.pt`` (Sprint 25, gradient_ghost). |
| | 2 | + |
| | 3 | +The pre-run ``gradient_ghost`` probe diagnoses adapter convergence |
| | 4 | +without any model load: dlm writes the optimizer state snapshot at |
| | 5 | +end of training, and we can read its ``global_step`` + per-parameter |
| | 6 | +Adam ``exp_avg_sq`` magnitudes to answer "did this train long enough?" |
| | 7 | +in ~50 ms. |
| | 8 | + |
| | 9 | +This module is **dlm-aware via file convention only** — we never |
| | 10 | +``import dlm``. We just ``torch.load`` a path the user gave us. The |
| | 11 | +file format is dlm's contract; we follow it conservatively and fail |
| | 12 | +loudly when the shape drifts. |
| | 13 | + |
| | 14 | +## File-shape reference (verified against dlm 2026-04-20 stores) |
| | 15 | + |
| | 16 | +``` |
| | 17 | +training_state.pt: |
| | 18 | + global_step: int |
| | 19 | + epoch: float |
| | 20 | + best_val_loss: float |
| | 21 | + optimizer_state_dict: |
| | 22 | + state: dict[int param_id → {step, exp_avg, exp_avg_sq}] |
| | 23 | + param_groups: list[dict] # lr, betas, eps, params: list[int] |
| | 24 | + scheduler_state_dict: dict |
| | 25 | + scaler_state_dict: dict | None |
| | 26 | + *_rng_state: tensor / tuple / None |
| | 27 | + pinned_versions: dict |
| | 28 | + base_model_revision: str |
| | 29 | + dlm_manifest_hash: str | None |
| | 30 | + use_qlora: bool |
| | 31 | +``` |
| | 32 | + |
| | 33 | +Per-param ids are integer indices, NOT names — name attribution lives |
| | 34 | +in ``_param_id_mapping.py``. |
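
## Usage sketch

A minimal caller (the ``<id>`` path segment is a placeholder for a dlm
store entry; error handling elided):

```
from pathlib import Path

snap = load_training_state(
    Path("~/.dlm/store/<id>/adapter/versions/v0001").expanduser()
)
print(snap.global_step, snap.epoch, len(snap.per_param))
```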
| | 35 | +""" |

from __future__ import annotations

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from dlm_sway.core.errors import MissingTrainingStateError, SwayError


class TrainingStateError(SwayError):
    """Raised when ``training_state.pt`` exists but can't be parsed."""


@dataclass(frozen=True, slots=True)
class ParamStat:
    """Optimizer state for a single trainable parameter (one int id).

    The probe reads ``exp_avg_sq_mean`` as a proxy for "how much
    gradient variance this parameter has seen recently." Adam's
    second-moment estimate is a moving average of squared gradients;
    a high value at end-of-training means the gradients haven't
    shrunk — typical of an undertrained parameter.
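
    A rough reading, with made-up numbers (illustrative, not dlm
    thresholds)::

        >>> s = ParamStat(param_id=3, step=1200, exp_avg_norm=0.8,
        ...               exp_avg_sq_mean=2e-3, numel=4096)
        >>> s.exp_avg_sq_mean > 1e-4  # second moment still large at end
        True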
    """

    param_id: int
    step: int
    exp_avg_norm: float
    exp_avg_sq_mean: float
    numel: int


@dataclass(frozen=True, slots=True)
class TrainingStateSnapshot:
    """Everything ``gradient_ghost`` needs from a ``training_state.pt``.

    Constructed by :func:`load_training_state`; consumed by
    ``GradientGhostProbe.run``. Frozen so a probe can pass it around
    without defensive copies.
    """

    global_step: int
    epoch: float
    best_val_loss: float
    per_param: tuple[ParamStat, ...]
    pinned_versions: dict[str, str]
    base_model_revision: str | None
    dlm_manifest_hash: str | None
    use_qlora: bool


def load_training_state(adapter_dir: Path) -> TrainingStateSnapshot:
    """Load + parse ``adapter_dir/training_state.pt``.

    Parameters
    ----------
    adapter_dir:
        A dlm adapter version directory (e.g.
        ``~/.dlm/store/<id>/adapter/versions/v0001/``). Must contain
        a ``training_state.pt`` file.

    Returns
    -------
    TrainingStateSnapshot
        Frozen snapshot of the fields the probe consumes.

    Raises
    ------
    MissingTrainingStateError
        ``training_state.pt`` doesn't exist under ``adapter_dir``. The
        caller should SKIP (this is a clean signal, not an error).
    TrainingStateError
        File exists but can't be parsed (torch import missing,
        unexpected dict shape, optimizer state missing). The caller
        should ERROR — something is structurally wrong.
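
    Examples
    --------
    A sketch of the SKIP-vs-ERROR split described above::

        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError:
            ...  # SKIP: adapter never saved a training state
        except TrainingStateError:
            ...  # ERROR: file present but structurally wrong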
    """
    adapter_dir = Path(adapter_dir)
    state_path = adapter_dir / "training_state.pt"
    if not state_path.exists():
        raise MissingTrainingStateError(adapter_dir)

    # Lazy-import torch so non-gradient-ghost users don't pay the
    # ~700 ms torch-import cost on `import dlm_sway`.
    try:
        import torch
    except ImportError as exc:
        raise TrainingStateError(
            "torch not installed — gradient_ghost reads PyTorch-pickled "
            "training_state.pt files. Install with: pip install 'dlm-sway[hf]'"
        ) from exc

    # ``weights_only=False`` is required: dlm's training_state.pt
    # carries pickled RNG state (numpy / python random). Suppressing
    # the ``FutureWarning`` keeps the probe output clean — this is a
    # known-trusted artifact dlm produced, not arbitrary user input.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        try:
            state = torch.load(str(state_path), map_location="cpu", weights_only=False)
        except Exception as exc:  # noqa: BLE001 — torch.load can raise many exception types
            raise TrainingStateError(
                f"failed to torch.load {state_path}: {type(exc).__name__}: {exc}"
            ) from exc

    if not isinstance(state, dict):
        raise TrainingStateError(f"{state_path}: expected dict, got {type(state).__name__}")
    opt = state.get("optimizer_state_dict")
    if not isinstance(opt, dict):
        raise TrainingStateError(
            f"{state_path}: missing 'optimizer_state_dict' (got {type(opt).__name__})"
        )
    per_param_state = opt.get("state")
    if not isinstance(per_param_state, dict):
        raise TrainingStateError(
            f"{state_path}: optimizer_state_dict.state is not a dict "
            f"(got {type(per_param_state).__name__})"
        )

    per_param: list[ParamStat] = []
    for pid, ps in per_param_state.items():
        if not isinstance(pid, int):
            raise TrainingStateError(
                f"{state_path}: optimizer_state_dict.state has non-int key "
                f"{pid!r} (type {type(pid).__name__})"
            )
        if not isinstance(ps, dict):
            continue  # Tolerate a malformed entry rather than fail the whole probe.
        step_v = _scalar_int(ps.get("step", 0))
        exp_avg = ps.get("exp_avg")
        exp_avg_sq = ps.get("exp_avg_sq")
        per_param.append(
            ParamStat(
                param_id=pid,
                step=step_v,
                exp_avg_norm=_tensor_norm(exp_avg),
                exp_avg_sq_mean=_tensor_mean(exp_avg_sq),
                numel=_tensor_numel(exp_avg_sq),
            )
        )
    per_param.sort(key=lambda s: s.param_id)

    return TrainingStateSnapshot(
        global_step=int(state.get("global_step", 0) or 0),
        epoch=float(state.get("epoch", 0.0) or 0.0),
        best_val_loss=float(state.get("best_val_loss", 0.0) or 0.0),
        per_param=tuple(per_param),
        pinned_versions=dict(state.get("pinned_versions") or {}),
        base_model_revision=state.get("base_model_revision"),
        dlm_manifest_hash=state.get("dlm_manifest_hash"),
        use_qlora=bool(state.get("use_qlora", False)),
    )


def _scalar_int(v: Any) -> int:
    """torch saves ``step`` as a 0-dim tensor; coerce safely."""
    if v is None:
        return 0
    item = getattr(v, "item", None)
    if callable(item):
        try:
            return int(item())
        except Exception:  # noqa: BLE001
            return 0
    try:
        return int(v)
    except Exception:  # noqa: BLE001
        return 0


def _tensor_norm(t: Any) -> float:
    """L2 norm of ``t`` as a float; 0.0 if ``t`` is absent or unreadable."""
    if t is None:
        return 0.0
    norm = getattr(t, "norm", None)
    if callable(norm):
        try:
            return float(norm().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_mean(t: Any) -> float:
    """Mean of ``t`` as a float; 0.0 if ``t`` is absent or unreadable."""
    if t is None:
        return 0.0
    mean = getattr(t, "mean", None)
    if callable(mean):
        try:
            return float(mean().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_numel(t: Any) -> int:
    """Element count of ``t``; 0 if ``t`` is absent or unreadable."""
    if t is None:
        return 0
    numel = getattr(t, "numel", None)
    if callable(numel):
        try:
            return int(numel())
        except Exception:  # noqa: BLE001
            return 0
    return 0
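

# Quick sanity check of the coercion helpers (illustrative; assumes torch
# is importable, and is not part of the probe's runtime path):
#
#     >>> import torch
#     >>> _scalar_int(torch.tensor(7))
#     7
#     >>> t = torch.ones(4)
#     >>> (_tensor_norm(t), _tensor_mean(t), _tensor_numel(t))
#     (2.0, 1.0, 4)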