tenseleyflow/sway / 4f12f3b

probes/gradient_ghost: pre-run training-health probe (S25 P4)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA      4f12f3b76f32ac9288179752b81d87bf70c8c2b1
Parents  74d94e2
Tree     e1b6ce0

2 changed files

Status  File                                    +    -
M       src/dlm_sway/probes/__init__.py         1    0
A       src/dlm_sway/probes/gradient_ghost.py   268  0
src/dlm_sway/probes/__init__.py (modified)
@@ -19,6 +19,7 @@ from dlm_sway.probes import ( # noqa: F401 — imports register the probes
     cluster_kl,
     delta_kl,
     external_perplexity,
+    gradient_ghost,
     leakage,
     null_adapter,
     paraphrase_invariance,
src/dlm_sway/probes/gradient_ghost.py (added)
@@ -0,0 +1,268 @@
+"""P3 / S25 — GradientGhostProbe (pre-run, cross-repo, no model load).
+
+Reads the optimizer-state snapshot dlm writes alongside every adapter
+version (``training_state.pt``) and answers one practical question
+**before any forward pass fires**: was this adapter trained long enough
+to converge?
+
+Loads in ~50 ms (a 1.5B-param adapter's optimizer state is a ~50 MB
+pickle). The full adherence suite costs 30+ s; running this first as
+a pre-flight check lets ``sway check`` short-circuit on obviously
+broken adapters without paying the model-load tax.
+
+## Signal ladder (most → least decisive)
+
+1. **``global_step < min_steps_threshold``** — primary signal.
+   Catches the 90%-case "user did ``--max-steps 5`` for a smoke test
+   and forgot to retrain." No analysis needed; verdict: FAIL.
+2. **NaN / zero ``exp_avg_sq``** — strong secondary signal.
+   Adam's second-moment estimate didn't accumulate any useful
+   variance, meaning gradients didn't propagate meaningfully. Often
+   co-occurs with case 1 but worth flagging independently for
+   trainings that *did* run many steps with broken loss.
+3. **Per-layer ratio: layer mean vs global mean of
+   ``exp_avg_sq``** — heuristic. Layers whose second moment sits
+   dramatically above the global mean still see large gradient
+   variance — they haven't converged. Verdict: FAIL when more than
+   ``layer_failure_frac`` of layers cross the threshold, WARN when
+   any do (worked example below).
+
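+For intuition on signal 3 (illustrative numbers, not measured
+output): with a global mean ``exp_avg_sq`` of 1e-6 and the default
+``undertrained_layer_ratio = 2.0``, a layer whose mean is 3e-6
+(ratio 3.0) is flagged, while a layer at 1.5e-6 (ratio 1.5) is not.
+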
+## Why no null calibration
+
+Other probes z-score against a null-adapter baseline ("how much
+signal does random noise produce?"). Gradient state has no
+meaningful null — there's no equivalent to a "random optimizer
+snapshot" that the LoRA's noise floor would settle into. Verdict
+thresholds are therefore explicit heuristics, documented as "tune
+from user feedback" rather than dressed up with fake calibration
+math.
+
+## Inputs
+
+The spec carries ``adapter_path`` (the directory containing
+``training_state.pt`` + ``adapter_model.safetensors``). Typically
+populated by the dlm autogen bridge from
+``DlmHandle.adapter_path``.
+
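+A minimal spec entry might look like this (illustrative: the field
+names come from ``GradientGhostSpec`` below, while the surrounding
+list shape follows sway's existing probe specs)::
+
+    - kind: gradient_ghost
+      name: ghost_check
+      adapter_path: ./adapters/v7
+      min_steps_threshold: 50
+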
+## Verdict thresholds (all configurable in spec)
+
+- ``min_steps_threshold = 50`` — below this is severely undertrained.
+- ``undertrained_layer_ratio = 2.0`` — a layer's mean ``exp_avg_sq``
+  must be > 2× the global mean to count as "still has high gradient
+  variance." Multiplicative threshold (relative, not absolute) so
+  the probe is architecture-agnostic.
+- ``layer_failure_frac = 0.3`` — FAIL if more than 30% of layers
+  cross the per-layer threshold; any smaller nonzero fraction is a
+  WARN (worked example below).
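+
+For example (illustrative counts, not measured output): with 12
+mapped layers and the defaults above, 4 flagged layers (4/12 ≈ 33%)
+is a FAIL, 1–3 flagged layers is a WARN, and none is a PASS.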
+"""
+
+from __future__ import annotations
+
+import math
+import statistics
+from pathlib import Path
+from typing import ClassVar, Literal
+
+from pydantic import Field
+
+from dlm_sway.core.errors import MissingTrainingStateError, SwayError
+from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
+from dlm_sway.probes._param_id_mapping import ParamMappingError, map_param_ids_to_layers
+from dlm_sway.probes._training_state import TrainingStateError, load_training_state
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
+
+
+class GradientGhostSpec(ProbeSpec):
+    """Spec for ``kind: gradient_ghost``."""
+
+    kind: Literal["gradient_ghost"] = "gradient_ghost"
+    adapter_path: str
+    """Path to a dlm adapter version directory (must contain
+    ``training_state.pt``). Resolved relative to the spec file's
+    cwd via the same convention sway already uses for
+    ``models.ft.adapter``."""
+    min_steps_threshold: int = Field(default=50, ge=1)
+    """``global_step`` below this → FAIL (severely undertrained)."""
+    undertrained_layer_ratio: float = Field(default=2.0, gt=1.0)
+    """A layer counts as 'high gradient variance' when its mean
+    ``exp_avg_sq`` exceeds ``ratio * global_mean``. Strictly > 1
+    (a ratio of 1 would flag every layer whose mean merely exceeds
+    the global average)."""
+    layer_failure_frac: float = Field(default=0.3, ge=0.0, le=1.0)
+    """FAIL when more than this fraction of layers crosses the
+    ``undertrained_layer_ratio`` threshold; any smaller nonzero
+    count of flagged layers is a WARN."""
+
+
+class GradientGhostProbe(Probe):
+    """Pre-run training-health probe."""
+
+    kind = "gradient_ghost"
+    spec_cls = GradientGhostSpec
+    category = "calibration"
+    needs_backend: ClassVar[bool] = False
+
+    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
+        del ctx  # No backend / sections / null_stats needed.
+        assert isinstance(spec, GradientGhostSpec)
+        adapter_dir = Path(spec.adapter_path).expanduser()
+
+        # === 1) Load + validate training state ===
+        try:
+            snap = load_training_state(adapter_dir)
+        except MissingTrainingStateError as exc:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.SKIP,
+                score=None,
+                message=str(exc),
+            )
+        except (TrainingStateError, SwayError) as exc:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.ERROR,
+                score=None,
+                message=str(exc),
+            )
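+        # A missing snapshot maps to SKIP (nothing to analyze); a
+        # snapshot that exists but fails to parse maps to ERROR.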
+
+        # === 2) Layer grouping (best-effort; degrades to no-grouping) ===
+        try:
+            grouping = map_param_ids_to_layers(adapter_dir, num_params=len(snap.per_param))
+            layer_count = grouping.num_layers
+            params_per_layer = grouping.params_per_layer
+        except ParamMappingError:
+            grouping = None
+            layer_count = 0
+            params_per_layer = 0
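+        # NOTE: with no grouping, the per-layer analysis below is
+        # skipped and frac_undertrained stays 0.0, so only the
+        # global_step and NaN/zero signals can fire.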
+
+        # === 3) Primary signal: global_step floor ===
+        if snap.global_step < spec.min_steps_threshold:
+            return safe_finalize(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.FAIL,
+                score=0.0,
+                raw=float(snap.global_step),
+                z_score=None,
+                evidence={
+                    "global_step": snap.global_step,
+                    "min_steps_threshold": spec.min_steps_threshold,
+                    "epoch": snap.epoch,
+                    "num_params": len(snap.per_param),
+                    "num_layers": layer_count,
+                    "best_val_loss": snap.best_val_loss
+                    if math.isfinite(snap.best_val_loss)
+                    else None,
+                    "primary_signal": "global_step_below_threshold",
+                },
+                message=(
+                    f"severely undertrained: global_step={snap.global_step} "
+                    f"< threshold {spec.min_steps_threshold}. Probe scores on "
+                    f"this adapter will be unreliable; consider retraining."
+                ),
+            )
+
+        # === 4) NaN / zero exp_avg_sq detection ===
+        finite_means = [
+            ps.exp_avg_sq_mean for ps in snap.per_param if math.isfinite(ps.exp_avg_sq_mean)
+        ]
+        nan_or_inf_count = len(snap.per_param) - len(finite_means)
+        if not finite_means:
+            return safe_finalize(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.FAIL,
+                score=0.0,
+                raw=0.0,
+                z_score=None,
+                evidence={
+                    "global_step": snap.global_step,
+                    "num_params": len(snap.per_param),
+                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
+                    "primary_signal": "all_optimizer_state_nan",
+                },
+                message=(
+                    "every per-param exp_avg_sq is NaN or non-finite — "
+                    "training didn't propagate gradients meaningfully."
+                ),
+            )
+        if all(m == 0.0 for m in finite_means):
+            # Signal 2 in the module docstring covers zeros as well as
+            # NaNs: an all-zero second moment means Adam never
+            # accumulated any gradient signal.
+            return safe_finalize(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.FAIL,
+                score=0.0,
+                raw=0.0,
+                z_score=None,
+                evidence={
+                    "global_step": snap.global_step,
+                    "num_params": len(snap.per_param),
+                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
+                    "primary_signal": "all_optimizer_state_zero",
+                },
+                message=(
+                    "every finite per-param exp_avg_sq is exactly zero; "
+                    "training didn't propagate gradients meaningfully."
+                ),
+            )
+
+        # === 5) Per-layer secondary signal ===
+        global_mean = statistics.fmean(finite_means)
+        per_layer_means: dict[int, float] = {}
+        per_layer_undertrained: list[int] = []
+
+        if grouping is not None and global_mean > 0.0:
+            # Group finite per-param means by layer index.
+            buckets: dict[int, list[float]] = {idx: [] for idx in grouping.layer_indices}
+            for ps in snap.per_param:
+                if not math.isfinite(ps.exp_avg_sq_mean):
+                    continue
+                layer = grouping.layer_of.get(ps.param_id)
+                if layer is not None:
+                    buckets[layer].append(ps.exp_avg_sq_mean)
+            for layer_idx, vals in buckets.items():
+                if not vals:
+                    continue
+                layer_mean = statistics.fmean(vals)
+                per_layer_means[layer_idx] = layer_mean
+                if layer_mean > spec.undertrained_layer_ratio * global_mean:
+                    per_layer_undertrained.append(layer_idx)
+
+        frac_undertrained = len(per_layer_undertrained) / layer_count if layer_count > 0 else 0.0
+
+        # Top-3 worst layers (highest ratio) — useful evidence even
+        # when no layer crosses the threshold.
+        ranked_layers = sorted(per_layer_means.items(), key=lambda kv: -kv[1])[:3]
+        worst_layers = [
+            {"layer": idx, "ratio": (mean / global_mean) if global_mean > 0 else None}
+            for idx, mean in ranked_layers
+        ]
+
+        # === 6) Final verdict ===
+        if frac_undertrained > spec.layer_failure_frac:
+            verdict = Verdict.FAIL
+            message = (
+                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
+                f"{layer_count}) still show high gradient variance — training-loss "
+                f"curve likely hasn't flattened. global_step={snap.global_step}."
+            )
+            score = 0.3  # Bottom of "partial" band.
+        elif per_layer_undertrained:
+            verdict = Verdict.WARN
+            message = (
+                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
+                f"{layer_count}) above {spec.undertrained_layer_ratio:.1f}× the global "
+                f"exp_avg_sq mean — adapter may be partially undertrained but other "
+                f"probes can still produce signal. global_step={snap.global_step}."
+            )
+            score = 0.7
+        else:
+            verdict = Verdict.PASS
+            message = (
+                f"global_step={snap.global_step}, no layer above "
+                f"{spec.undertrained_layer_ratio:.1f}× global exp_avg_sq mean — "
+                f"training looks converged."
+            )
+            score = 0.9
+
+        return safe_finalize(
+            name=spec.name,
+            kind=spec.kind,
+            verdict=verdict,
+            score=score,
+            raw=frac_undertrained,
+            z_score=None,
+            evidence={
+                "global_step": snap.global_step,
+                "epoch": snap.epoch,
+                "num_params": len(snap.per_param),
+                "num_layers": layer_count,
+                "params_per_layer": params_per_layer,
+                "global_mean_exp_avg_sq": global_mean,
+                "frac_layers_undertrained": frac_undertrained,
+                "num_layers_undertrained": len(per_layer_undertrained),
+                "worst_layers": worst_layers,
+                "num_nonfinite_exp_avg_sq": nan_or_inf_count,
+                "best_val_loss": snap.best_val_loss if math.isfinite(snap.best_val_loss) else None,
+                "use_qlora": snap.use_qlora,
+            },
+            message=message,
+        )