tenseleyflow/sway / f6e20a1

probes/_training_state: torch.load wrapper + ParamStat dataclass + MissingTrainingStateError (S25 P1)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: f6e20a1d1c8d77a18942d23aae68ff22e35e269f
Parents: e69bca5
Tree: b44465d

2 changed files

Status  File                                     +    -
M       src/dlm_sway/core/errors.py              19   0
A       src/dlm_sway/probes/_training_state.py   239  0
src/dlm_sway/core/errors.py  (modified)
@@ -65,6 +65,25 @@ class ProbeError(SwayError):
         self.probe = probe
 
 
+class MissingTrainingStateError(SwayError):
+    """The pre-run probes (S25 ``gradient_ghost``) couldn't find a
+    ``training_state.pt`` next to the adapter.
+
+    Distinguishes "the file legitimately doesn't exist for this adapter"
+    (probe SKIPs cleanly) from "the file exists but won't load"
+    (probe ERRORs). Pre-run probes catch this and emit SKIP rather
+    than letting the missing file kill the suite.
+    """
+
+    def __init__(self, adapter_path: object) -> None:
+        super().__init__(
+            f"no training_state.pt under {adapter_path} — adapter wasn't "
+            f"produced by dlm or the file was pruned. Pre-run diagnostics "
+            f"(gradient_ghost) will SKIP for this adapter."
+        )
+        self.adapter_path = adapter_path
+
+
 class DlmCompatError(SwayError):
     """The installed ``dlm`` package's public surface doesn't match what
     sway's resolver expects.
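
The useful bit of the new exception is the contract it encodes: a pre-run probe can branch on the exception type alone to decide between SKIP and ERROR. Below is a minimal sketch of a caller honoring that contract; the `run_gradient_ghost` wrapper and its string statuses are hypothetical illustrations, not part of this commit (`TrainingStateError` and `load_training_state` come from the second file below):

```python
from pathlib import Path

from dlm_sway.core.errors import MissingTrainingStateError
from dlm_sway.probes._training_state import TrainingStateError, load_training_state


def run_gradient_ghost(adapter_dir: Path) -> str:
    """Hypothetical pre-run wrapper mapping exception type to probe status."""
    try:
        snapshot = load_training_state(adapter_dir)
    except MissingTrainingStateError:
        return "SKIP"   # file legitimately absent: clean skip, suite continues
    except TrainingStateError:
        return "ERROR"  # file present but unreadable: structurally wrong
    # Any real convergence judgment happens here; "did we train at all"
    # is the cheapest possible check.
    return "PASS" if snapshot.global_step > 0 else "WARN"
```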
src/dlm_sway/probes/_training_state.py  (added)
@@ -0,0 +1,239 @@
+"""Loader for dlm's ``training_state.pt`` (Sprint 25, gradient_ghost).
+
+The pre-run ``gradient_ghost`` probe diagnoses adapter convergence
+without any model load: dlm writes the optimizer state snapshot at
+end of training, and we can read its ``global_step`` + per-parameter
+Adam ``exp_avg_sq`` magnitudes to answer "did this train long enough?"
+in ~50 ms.
+
+This module is **dlm-aware via file convention only** — we never
+``import dlm``. We just ``torch.load`` a path the user gave us. The
+file format is dlm's contract; we follow it conservatively and fail
+loudly when the shape drifts.
+
+## File-shape reference (verified against dlm 2026-04-20 stores)
+
+```
+training_state.pt:
+  global_step: int
+  epoch: float
+  best_val_loss: float
+  optimizer_state_dict:
+    state: dict[int param_id → {step, exp_avg, exp_avg_sq}]
+    param_groups: list[dict]   # lr, betas, eps, params: list[int]
+  scheduler_state_dict: dict
+  scaler_state_dict: dict | None
+  *_rng_state: tensor / tuple / None
+  pinned_versions: dict
+  base_model_revision: str
+  dlm_manifest_hash: str | None
+  use_qlora: bool
+```
+
+Per-param ids are integer indices, NOT names — name attribution lives
+in ``_param_id_mapping.py``.
+"""
+
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from dlm_sway.core.errors import MissingTrainingStateError, SwayError
+
+
+class TrainingStateError(SwayError):
+    """Raised when ``training_state.pt`` exists but can't be parsed."""
+
+
+@dataclass(frozen=True, slots=True)
+class ParamStat:
+    """Optimizer state for a single trainable parameter (one int id).
+
+    The probe reads ``exp_avg_sq_mean`` as a proxy for "how much
+    gradient variance this parameter has seen recently." Adam's
+    second-moment estimate is a moving average of squared gradients;
+    a high value at end-of-training means the gradients haven't
+    shrunk — typical of an undertrained parameter.
+    """
+
+    param_id: int
+    step: int
+    exp_avg_norm: float
+    exp_avg_sq_mean: float
+    numel: int
+
+
+@dataclass(frozen=True, slots=True)
+class TrainingStateSnapshot:
+    """Everything ``gradient_ghost`` needs from a ``training_state.pt``.
+
+    Constructed by :func:`load_training_state`; consumed by
+    ``GradientGhostProbe.run``. Frozen so a probe can pass it around
+    without defensive copies.
+    """
+
+    global_step: int
+    epoch: float
+    best_val_loss: float
+    per_param: tuple[ParamStat, ...]
+    pinned_versions: dict[str, str]
+    base_model_revision: str | None
+    dlm_manifest_hash: str | None
+    use_qlora: bool
+
+
+def load_training_state(adapter_dir: Path) -> TrainingStateSnapshot:
+    """Load + parse ``adapter_dir/training_state.pt``.
+
+    Parameters
+    ----------
+    adapter_dir:
+        A dlm adapter version directory (e.g.
+        ``~/.dlm/store/<id>/adapter/versions/v0001/``). Must contain
+        a ``training_state.pt`` file.
+
+    Returns
+    -------
+    TrainingStateSnapshot
+        Frozen snapshot of the fields the probe consumes.
+
+    Raises
+    ------
+    MissingTrainingStateError
+        ``training_state.pt`` doesn't exist under ``adapter_dir``. The
+        caller should SKIP (this is a clean signal, not an error).
+    TrainingStateError
+        File exists but can't be parsed (torch import missing,
+        unexpected dict shape, optimizer state missing). The caller
+        should ERROR — something is structurally wrong.
+    """
+    adapter_dir = Path(adapter_dir)
+    state_path = adapter_dir / "training_state.pt"
+    if not state_path.exists():
+        raise MissingTrainingStateError(adapter_dir)
+
+    # Lazy-import torch so non-gradient-ghost users don't pay the
+    # ~700 ms torch-import cost on `import dlm_sway`.
+    try:
+        import torch
+    except ImportError as exc:
+        raise TrainingStateError(
+            "torch not installed — gradient_ghost reads pytorch-pickled "
+            "training_state.pt files. Install with: pip install 'dlm-sway[hf]'"
+        ) from exc
+
+    # ``weights_only=False`` is required: dlm's training_state.pt
+    # carries pickled RNG state (numpy / python random). Suppressing
+    # the ``FutureWarning`` keeps the probe output clean — this is a
+    # known-trusted artifact dlm produced, not arbitrary user input.
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=FutureWarning)
+        try:
+            state = torch.load(str(state_path), map_location="cpu", weights_only=False)
+        except Exception as exc:  # noqa: BLE001 — torch.load can raise many shapes
+            raise TrainingStateError(
+                f"failed to torch.load {state_path}: {type(exc).__name__}: {exc}"
+            ) from exc
+
+    if not isinstance(state, dict):
+        raise TrainingStateError(f"{state_path}: expected dict, got {type(state).__name__}")
+    opt = state.get("optimizer_state_dict")
+    if not isinstance(opt, dict):
+        raise TrainingStateError(
+            f"{state_path}: missing 'optimizer_state_dict' (got {type(opt).__name__})"
+        )
+    per_param_state = opt.get("state")
+    if not isinstance(per_param_state, dict):
+        raise TrainingStateError(
+            f"{state_path}: optimizer_state_dict.state is not a dict "
+            f"(got {type(per_param_state).__name__})"
+        )
+
+    per_param: list[ParamStat] = []
+    for pid, ps in per_param_state.items():
+        if not isinstance(pid, int):
+            raise TrainingStateError(
+                f"{state_path}: optimizer_state_dict.state has non-int key "
+                f"{pid!r} (type {type(pid).__name__})"
+            )
+        if not isinstance(ps, dict):
+            continue  # Skip malformed entries silently — same as torch.optim does.
+        step_v = _scalar_int(ps.get("step", 0))
+        exp_avg = ps.get("exp_avg")
+        exp_avg_sq = ps.get("exp_avg_sq")
+        per_param.append(
+            ParamStat(
+                param_id=pid,
+                step=step_v,
+                exp_avg_norm=_tensor_norm(exp_avg),
+                exp_avg_sq_mean=_tensor_mean(exp_avg_sq),
+                numel=_tensor_numel(exp_avg_sq),
+            )
+        )
+    per_param.sort(key=lambda s: s.param_id)
+
+    return TrainingStateSnapshot(
+        global_step=int(state.get("global_step", 0) or 0),
+        epoch=float(state.get("epoch", 0.0) or 0.0),
+        best_val_loss=float(state.get("best_val_loss", 0.0) or 0.0),
+        per_param=tuple(per_param),
+        pinned_versions=dict(state.get("pinned_versions") or {}),
+        base_model_revision=state.get("base_model_revision"),
+        dlm_manifest_hash=state.get("dlm_manifest_hash"),
+        use_qlora=bool(state.get("use_qlora", False)),
+    )
+
+
+def _scalar_int(v: Any) -> int:
+    """torch saves ``step`` as a 0-dim tensor; coerce safely."""
+    if v is None:
+        return 0
+    item = getattr(v, "item", None)
+    if callable(item):
+        try:
+            return int(item())
+        except Exception:  # noqa: BLE001
+            return 0
+    try:
+        return int(v)
+    except Exception:  # noqa: BLE001
+        return 0
+
+
+def _tensor_norm(t: Any) -> float:
+    if t is None:
+        return 0.0
+    norm = getattr(t, "norm", None)
+    if callable(norm):
+        try:
+            return float(norm().item())
+        except Exception:  # noqa: BLE001
+            return 0.0
+    return 0.0
+
+
+def _tensor_mean(t: Any) -> float:
+    if t is None:
+        return 0.0
+    mean = getattr(t, "mean", None)
+    if callable(mean):
+        try:
+            return float(mean().item())
+        except Exception:  # noqa: BLE001
+            return 0.0
+    return 0.0
+
+
+def _tensor_numel(t: Any) -> int:
+    if t is None:
+        return 0
+    numel = getattr(t, "numel", None)
+    if callable(numel):
+        try:
+            return int(numel())
+        except Exception:  # noqa: BLE001
+            return 0
+    return 0
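
For orientation, end-to-end consumption of the loader looks like the sketch below. The store id in the path is a placeholder, and the `1e-6` cutoff is purely illustrative (this commit defines no threshold). Per the `ParamStat` docstring, Adam's second moment is an exponential moving average of squared gradients, `v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2`, so a large end-of-training mean says recent gradients were still large:

```python
from pathlib import Path

from dlm_sway.probes._training_state import load_training_state

# Placeholder path: substitute a real dlm store id for <id>.
adapter_dir = Path("~/.dlm/store/<id>/adapter/versions/v0001").expanduser()
snap = load_training_state(adapter_dir)

print(f"step={snap.global_step} epoch={snap.epoch} best_val_loss={snap.best_val_loss}")

# Illustrative only: flag parameters whose second-moment mean is still "hot"
# at end of training, i.e. whose gradients had not shrunk yet.
hot = [p for p in snap.per_param if p.exp_avg_sq_mean > 1e-6]
print(f"{len(hot)}/{len(snap.per_param)} params still carry gradient variance")
```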