| 1 | """Loader for dlm's ``training_state.pt`` (Sprint 25, gradient_ghost). |
| 2 | |
| 3 | The pre-run ``gradient_ghost`` probe diagnoses adapter convergence |
| 4 | without any model load: dlm writes the optimizer state snapshot at |
| 5 | end of training, and we can read its ``global_step`` + per-parameter |
| 6 | Adam ``exp_avg_sq`` magnitudes to answer "did this train long enough?" |
| 7 | in ~50 ms. |
| 8 | |
| 9 | This module is **dlm-aware via file convention only** — we never |
| 10 | ``import dlm``. We just ``torch.load`` a path the user gave us. The |
| 11 | file format is dlm's contract; we follow it conservatively and fail |
| 12 | loudly when the shape drifts. |
| 13 | |
| 14 | ## File-shape reference (verified against dlm 2026-04-20 stores) |
| 15 | |
| 16 | ``` |
| 17 | training_state.pt: |
| 18 | global_step: int |
| 19 | epoch: float |
| 20 | best_val_loss: float |
| 21 | optimizer_state_dict: |
| 22 | state: dict[int param_id → {step, exp_avg, exp_avg_sq}] |
| 23 | param_groups: list[dict] # lr, betas, eps, params: list[int] |
| 24 | scheduler_state_dict: dict |
| 25 | scaler_state_dict: dict | None |
| 26 | *_rng_state: tensor / tuple / None |
| 27 | pinned_versions: dict |
| 28 | base_model_revision: str |
| 29 | dlm_manifest_hash: str | None |
| 30 | use_qlora: bool |
| 31 | ``` |
| 32 | |
| 33 | Per-param ids are integer indices, NOT names — name attribution lives |
| 34 | in ``_param_id_mapping.py``. |
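
## Quick usage sketch

Illustrative only; ``<id>`` stands in for a concrete store entry id:

```
snap = load_training_state(Path("~/.dlm/store/<id>/adapter/versions/v0001").expanduser())
print(snap.global_step, len(snap.per_param))
```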
| 35 | """ |
| 36 | |
| 37 | from __future__ import annotations |
| 38 | |
| 39 | import warnings |
| 40 | from dataclasses import dataclass |
| 41 | from pathlib import Path |
| 42 | from typing import Any |
| 43 | |
| 44 | from dlm_sway.core.errors import MissingTrainingStateError, SwayError |
| 45 | |
| 46 | |
| 47 | class TrainingStateError(SwayError): |
| 48 | """Raised when ``training_state.pt`` exists but can't be parsed.""" |
| 49 | |
| 50 | |
| 51 | @dataclass(frozen=True, slots=True) |
| 52 | class ParamStat: |
| 53 | """Optimizer state for a single trainable parameter (one int id). |
| 54 | |
| 55 | The probe reads ``exp_avg_sq_mean`` as a proxy for "how much |
| 56 | gradient variance this parameter has seen recently." Adam's |
| 57 | second-moment estimate is a moving average of squared gradients; |
| 58 | a high value at end-of-training means the gradients haven't |
| 59 | shrunk — typical of an undertrained parameter. |
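
    Concretely, with decay rate ``beta2``, Adam tracks
    ``v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2`` elementwise, so
    ``exp_avg_sq_mean`` is the mean of that buffer across all elements.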
| 60 | """ |
| 61 | |
| 62 | param_id: int |
| 63 | step: int |
| 64 | exp_avg_norm: float |
| 65 | exp_avg_sq_mean: float |
| 66 | numel: int |
| 67 | |
| 68 | |
| 69 | @dataclass(frozen=True, slots=True) |
| 70 | class TrainingStateSnapshot: |
| 71 | """Everything ``gradient_ghost`` needs from a ``training_state.pt``. |
| 72 | |
| 73 | Constructed by :func:`load_training_state`; consumed by |
| 74 | ``GradientGhostProbe.run``. Frozen so a probe can pass it around |
| 75 | without defensive copies. |
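
    An illustrative downstream check (the ``1e-4`` threshold is a
    made-up example, not a dlm constant)::

        hot = [p for p in snap.per_param if p.exp_avg_sq_mean > 1e-4]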
| 76 | """ |
| 77 | |
| 78 | global_step: int |
| 79 | epoch: float |
| 80 | best_val_loss: float |
| 81 | per_param: tuple[ParamStat, ...] |
| 82 | pinned_versions: dict[str, str] |
| 83 | base_model_revision: str | None |
| 84 | dlm_manifest_hash: str | None |
| 85 | use_qlora: bool |
| 86 | |
| 87 | |
| 88 | def load_training_state(adapter_dir: Path) -> TrainingStateSnapshot: |
| 89 | """Load + parse ``adapter_dir/training_state.pt``. |
| 90 | |
| 91 | Parameters |
| 92 | ---------- |
| 93 | adapter_dir: |
| 94 | A dlm adapter version directory (e.g. |
| 95 | ``~/.dlm/store/<id>/adapter/versions/v0001/``). Must contain |
| 96 | a ``training_state.pt`` file. |
| 97 | |
| 98 | Returns |
| 99 | ------- |
| 100 | TrainingStateSnapshot |
| 101 | Frozen snapshot of the fields the probe consumes. |
| 102 | |
| 103 | Raises |
| 104 | ------ |
| 105 | MissingTrainingStateError |
| 106 | ``training_state.pt`` doesn't exist under ``adapter_dir``. The |
| 107 | caller should SKIP (this is a clean signal, not an error). |
| 108 | TrainingStateError |
| 109 | File exists but can't be parsed (torch import missing, |
| 110 | unexpected dict shape, optimizer state missing). The caller |
| 111 | should ERROR — something is structurally wrong. |
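
    Examples
    --------
    Illustrative caller-side handling of the documented contract::

        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError:
            ...  # SKIP: adapter version has no snapshot
        except TrainingStateError:
            ...  # ERROR: snapshot present but structurally wrong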
| 112 | """ |
| 113 | adapter_dir = Path(adapter_dir) |
| 114 | state_path = adapter_dir / "training_state.pt" |
| 115 | if not state_path.exists(): |
| 116 | raise MissingTrainingStateError(adapter_dir) |
| 117 | |
| 118 | # Lazy-import torch so non-gradient-ghost users don't pay the |
| 119 | # ~700 ms torch-import cost on `import dlm_sway`. |
| 120 | try: |
| 121 | import torch |
| 122 | except ImportError as exc: |
| 123 | raise TrainingStateError( |
| 124 | "torch not installed — gradient_ghost reads pytorch-pickled " |
| 125 | "training_state.pt files. Install with: pip install 'dlm-sway[hf]'" |
| 126 | ) from exc |
| 127 | |
| 128 | # ``weights_only=False`` is required: dlm's training_state.pt |
| 129 | # carries pickled RNG state (numpy / python random). Suppressing |
| 130 | # the ``FutureWarning`` keeps the probe output clean — this is a |
| 131 | # known-trusted artifact dlm produced, not arbitrary user input. |
| 132 | with warnings.catch_warnings(): |
| 133 | warnings.filterwarnings("ignore", category=FutureWarning) |
| 134 | try: |
| 135 | state = torch.load(str(state_path), map_location="cpu", weights_only=False) |
| 136 | except Exception as exc: # noqa: BLE001 — torch.load can raise many shapes |
| 137 | raise TrainingStateError( |
| 138 | f"failed to torch.load {state_path}: {type(exc).__name__}: {exc}" |
| 139 | ) from exc |

    if not isinstance(state, dict):
        raise TrainingStateError(f"{state_path}: expected dict, got {type(state).__name__}")
    opt = state.get("optimizer_state_dict")
    if not isinstance(opt, dict):
        raise TrainingStateError(
            f"{state_path}: missing 'optimizer_state_dict' (got {type(opt).__name__})"
        )
    per_param_state = opt.get("state")
    if not isinstance(per_param_state, dict):
        raise TrainingStateError(
            f"{state_path}: optimizer_state_dict.state is not a dict "
            f"(got {type(per_param_state).__name__})"
        )

    per_param: list[ParamStat] = []
    for pid, ps in per_param_state.items():
        if not isinstance(pid, int):
            raise TrainingStateError(
                f"{state_path}: optimizer_state_dict.state has non-int key "
                f"{pid!r} (type {type(pid).__name__})"
            )
        if not isinstance(ps, dict):
            continue  # Skip malformed entries silently — same as torch.optim does.
        step_v = _scalar_int(ps.get("step", 0))
        exp_avg = ps.get("exp_avg")
        exp_avg_sq = ps.get("exp_avg_sq")
        per_param.append(
            ParamStat(
                param_id=pid,
                step=step_v,
                exp_avg_norm=_tensor_norm(exp_avg),
                exp_avg_sq_mean=_tensor_mean(exp_avg_sq),
                numel=_tensor_numel(exp_avg_sq),
            )
        )
    per_param.sort(key=lambda s: s.param_id)

    return TrainingStateSnapshot(
        global_step=int(state.get("global_step", 0) or 0),
        epoch=float(state.get("epoch", 0.0) or 0.0),
        best_val_loss=float(state.get("best_val_loss", 0.0) or 0.0),
        per_param=tuple(per_param),
        pinned_versions=dict(state.get("pinned_versions") or {}),
        base_model_revision=state.get("base_model_revision"),
        dlm_manifest_hash=state.get("dlm_manifest_hash"),
        use_qlora=bool(state.get("use_qlora", False)),
    )


def _scalar_int(v: Any) -> int:
    """torch saves ``step`` as a 0-dim tensor; coerce safely."""
    if v is None:
        return 0
    item = getattr(v, "item", None)
    if callable(item):
        try:
            return int(item())
        except Exception:  # noqa: BLE001
            return 0
    try:
        return int(v)
    except Exception:  # noqa: BLE001
        return 0


def _tensor_norm(t: Any) -> float:
    """L2 norm of ``t`` as a float; 0.0 if ``t`` is missing or not tensor-like."""
    if t is None:
        return 0.0
    norm = getattr(t, "norm", None)
    if callable(norm):
        try:
            return float(norm().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_mean(t: Any) -> float:
    """Mean of ``t`` as a float; 0.0 if ``t`` is missing or not tensor-like."""
    if t is None:
        return 0.0
    mean = getattr(t, "mean", None)
    if callable(mean):
        try:
            return float(mean().item())
        except Exception:  # noqa: BLE001
            return 0.0
    return 0.0


def _tensor_numel(t: Any) -> int:
    """Element count of ``t``; 0 if ``t`` is missing or not tensor-like."""
    if t is None:
        return 0
    numel = getattr(t, "numel", None)
    if callable(numel):
        try:
            return int(numel())
        except Exception:  # noqa: BLE001
            return 0
    return 0