@@ -0,0 +1,239 @@ |
| 1 | +"""Loader for dlm's ``training_state.pt`` (Sprint 25, gradient_ghost). |
| 2 | + |
| 3 | +The pre-run ``gradient_ghost`` probe diagnoses adapter convergence |
| 4 | +without any model load: dlm writes the optimizer state snapshot at |
| 5 | +end of training, and we can read its ``global_step`` + per-parameter |
| 6 | +Adam ``exp_avg_sq`` magnitudes to answer "did this train long enough?" |
| 7 | +in ~50 ms. |
| 8 | + |
| 9 | +This module is **dlm-aware via file convention only** — we never |
| 10 | +``import dlm``. We just ``torch.load`` a path the user gave us. The |
| 11 | +file format is dlm's contract; we follow it conservatively and fail |
| 12 | +loudly when the shape drifts. |
| 13 | + |
| 14 | +## File-shape reference (verified against dlm 2026-04-20 stores) |
| 15 | + |
| 16 | +``` |
| 17 | +training_state.pt: |
| 18 | + global_step: int |
| 19 | + epoch: float |
| 20 | + best_val_loss: float |
| 21 | + optimizer_state_dict: |
| 22 | + state: dict[int param_id → {step, exp_avg, exp_avg_sq}] |
| 23 | + param_groups: list[dict] # lr, betas, eps, params: list[int] |
| 24 | + scheduler_state_dict: dict |
| 25 | + scaler_state_dict: dict | None |
| 26 | + *_rng_state: tensor / tuple / None |
| 27 | + pinned_versions: dict |
| 28 | + base_model_revision: str |
| 29 | + dlm_manifest_hash: str | None |
| 30 | + use_qlora: bool |
| 31 | +``` |
| 32 | + |
| 33 | +Per-param ids are integer indices, NOT names — name attribution lives |
| 34 | +in ``_param_id_mapping.py``. |
| 35 | +""" |
| 36 | + |
| 37 | +from __future__ import annotations |
| 38 | + |
| 39 | +import warnings |
| 40 | +from dataclasses import dataclass |
| 41 | +from pathlib import Path |
| 42 | +from typing import Any |
| 43 | + |
| 44 | +from dlm_sway.core.errors import MissingTrainingStateError, SwayError |
| 45 | + |
| 46 | + |
class TrainingStateError(SwayError):
    """A ``training_state.pt`` was found on disk but its contents could not be parsed."""
| 49 | + |
| 50 | + |
| 51 | +@dataclass(frozen=True, slots=True) |
| 52 | +class ParamStat: |
| 53 | + """Optimizer state for a single trainable parameter (one int id). |
| 54 | + |
| 55 | + The probe reads ``exp_avg_sq_mean`` as a proxy for "how much |
| 56 | + gradient variance this parameter has seen recently." Adam's |
| 57 | + second-moment estimate is a moving average of squared gradients; |
| 58 | + a high value at end-of-training means the gradients haven't |
| 59 | + shrunk — typical of an undertrained parameter. |
| 60 | + """ |
| 61 | + |
| 62 | + param_id: int |
| 63 | + step: int |
| 64 | + exp_avg_norm: float |
| 65 | + exp_avg_sq_mean: float |
| 66 | + numel: int |
| 67 | + |
| 68 | + |
| 69 | +@dataclass(frozen=True, slots=True) |
| 70 | +class TrainingStateSnapshot: |
| 71 | + """Everything ``gradient_ghost`` needs from a ``training_state.pt``. |
| 72 | + |
| 73 | + Constructed by :func:`load_training_state`; consumed by |
| 74 | + ``GradientGhostProbe.run``. Frozen so a probe can pass it around |
| 75 | + without defensive copies. |
| 76 | + """ |
| 77 | + |
| 78 | + global_step: int |
| 79 | + epoch: float |
| 80 | + best_val_loss: float |
| 81 | + per_param: tuple[ParamStat, ...] |
| 82 | + pinned_versions: dict[str, str] |
| 83 | + base_model_revision: str | None |
| 84 | + dlm_manifest_hash: str | None |
| 85 | + use_qlora: bool |
| 86 | + |
| 87 | + |
def load_training_state(adapter_dir: Path) -> TrainingStateSnapshot:
    """Load + parse ``adapter_dir/training_state.pt``.

    Parameters
    ----------
    adapter_dir:
        A dlm adapter version directory (e.g.
        ``~/.dlm/store/<id>/adapter/versions/v0001/``). Must contain
        a ``training_state.pt`` file.

    Returns
    -------
    TrainingStateSnapshot
        Frozen snapshot of the fields the probe consumes.

    Raises
    ------
    MissingTrainingStateError
        ``training_state.pt`` doesn't exist under ``adapter_dir``. The
        caller should SKIP (this is a clean signal, not an error).
    TrainingStateError
        File exists but can't be parsed (torch import missing,
        unexpected dict shape, optimizer state missing). The caller
        should ERROR — something is structurally wrong.
    """
    adapter_dir = Path(adapter_dir)
    state_path = adapter_dir / "training_state.pt"
    if not state_path.exists():
        raise MissingTrainingStateError(adapter_dir)

    state = _torch_load_trusted(state_path)

    # Validate the top-level shape before touching any nested field, so
    # every drift from dlm's contract fails with a path-qualified message.
    if not isinstance(state, dict):
        raise TrainingStateError(f"{state_path}: expected dict, got {type(state).__name__}")
    opt = state.get("optimizer_state_dict")
    if not isinstance(opt, dict):
        raise TrainingStateError(
            f"{state_path}: missing 'optimizer_state_dict' (got {type(opt).__name__})"
        )
    per_param_state = opt.get("state")
    if not isinstance(per_param_state, dict):
        raise TrainingStateError(
            f"{state_path}: optimizer_state_dict.state is not a dict "
            f"(got {type(per_param_state).__name__})"
        )

    # ``or 0`` / ``or 0.0`` also normalizes an explicit None in the file.
    return TrainingStateSnapshot(
        global_step=int(state.get("global_step", 0) or 0),
        epoch=float(state.get("epoch", 0.0) or 0.0),
        best_val_loss=float(state.get("best_val_loss", 0.0) or 0.0),
        per_param=_collect_param_stats(state_path, per_param_state),
        pinned_versions=dict(state.get("pinned_versions") or {}),
        base_model_revision=state.get("base_model_revision"),
        dlm_manifest_hash=state.get("dlm_manifest_hash"),
        use_qlora=bool(state.get("use_qlora", False)),
    )


def _torch_load_trusted(state_path: Path) -> Any:
    """Deserialize ``state_path`` with torch, wrapping failures in TrainingStateError."""
    # Lazy-import torch so non-gradient-ghost users don't pay the
    # ~700 ms torch-import cost on `import dlm_sway`.
    try:
        import torch
    except ImportError as exc:
        raise TrainingStateError(
            "torch not installed — gradient_ghost reads pytorch-pickled "
            "training_state.pt files. Install with: pip install 'dlm-sway[hf]'"
        ) from exc

    # ``weights_only=False`` is required: dlm's training_state.pt
    # carries pickled RNG state (numpy / python random). Suppressing
    # the ``FutureWarning`` keeps the probe output clean — this is a
    # known-trusted artifact dlm produced, not arbitrary user input.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        try:
            return torch.load(str(state_path), map_location="cpu", weights_only=False)
        except Exception as exc:  # noqa: BLE001 — torch.load can raise many shapes
            raise TrainingStateError(
                f"failed to torch.load {state_path}: {type(exc).__name__}: {exc}"
            ) from exc


def _collect_param_stats(state_path: Path, per_param_state: dict) -> tuple[ParamStat, ...]:
    """Summarize each per-parameter Adam entry into a ParamStat, sorted by id."""
    per_param: list[ParamStat] = []
    for pid, ps in per_param_state.items():
        # Int keys are dlm's contract; anything else means the format drifted.
        if not isinstance(pid, int):
            raise TrainingStateError(
                f"{state_path}: optimizer_state_dict.state has non-int key "
                f"{pid!r} (type {type(pid).__name__})"
            )
        if not isinstance(ps, dict):
            continue  # Skip malformed entries silently — same as torch.optim does.
        step_v = _scalar_int(ps.get("step", 0))
        exp_avg = ps.get("exp_avg")
        exp_avg_sq = ps.get("exp_avg_sq")
        per_param.append(
            ParamStat(
                param_id=pid,
                step=step_v,
                exp_avg_norm=_tensor_norm(exp_avg),
                exp_avg_sq_mean=_tensor_mean(exp_avg_sq),
                numel=_tensor_numel(exp_avg_sq),
            )
        )
    per_param.sort(key=lambda s: s.param_id)
    return tuple(per_param)
| 188 | + |
| 189 | + |
| 190 | +def _scalar_int(v: Any) -> int: |
| 191 | + """torch saves ``step`` as a 0-dim tensor; coerce safely.""" |
| 192 | + if v is None: |
| 193 | + return 0 |
| 194 | + item = getattr(v, "item", None) |
| 195 | + if callable(item): |
| 196 | + try: |
| 197 | + return int(item()) |
| 198 | + except Exception: # noqa: BLE001 |
| 199 | + return 0 |
| 200 | + try: |
| 201 | + return int(v) |
| 202 | + except Exception: # noqa: BLE001 |
| 203 | + return 0 |
| 204 | + |
| 205 | + |
| 206 | +def _tensor_norm(t: Any) -> float: |
| 207 | + if t is None: |
| 208 | + return 0.0 |
| 209 | + norm = getattr(t, "norm", None) |
| 210 | + if callable(norm): |
| 211 | + try: |
| 212 | + return float(norm().item()) |
| 213 | + except Exception: # noqa: BLE001 |
| 214 | + return 0.0 |
| 215 | + return 0.0 |
| 216 | + |
| 217 | + |
| 218 | +def _tensor_mean(t: Any) -> float: |
| 219 | + if t is None: |
| 220 | + return 0.0 |
| 221 | + mean = getattr(t, "mean", None) |
| 222 | + if callable(mean): |
| 223 | + try: |
| 224 | + return float(mean().item()) |
| 225 | + except Exception: # noqa: BLE001 |
| 226 | + return 0.0 |
| 227 | + return 0.0 |
| 228 | + |
| 229 | + |
| 230 | +def _tensor_numel(t: Any) -> int: |
| 231 | + if t is None: |
| 232 | + return 0 |
| 233 | + numel = getattr(t, "numel", None) |
| 234 | + if callable(numel): |
| 235 | + try: |
| 236 | + return int(numel()) |
| 237 | + except Exception: # noqa: BLE001 |
| 238 | + return 0 |
| 239 | + return 0 |