1 """P3 / S25 — GradientGhostProbe (pre-run, cross-repo, no model load).
2
3 Reads the optimizer-state snapshot dlm writes alongside every adapter
4 version (``training_state.pt``) and answers one practical question
5 **before any forward pass fires**: was this adapter trained long enough
6 to converge?
7
8 Loads in ~50 ms (a 1.5B-param adapter's optimizer state is ~50 MB
9 pickle). The full adherence suite costs 30+ s; running this first as
10 a pre-flight check lets ``sway check`` short-circuit on obviously-
11 broken adapters without paying the model-load tax.
12
## Signal ladder (most → least decisive)

1. **``global_step < min_steps_threshold``** — primary signal.
   Catches the 90%-case "user did ``--max-steps 5`` for a smoke test
   and forgot to retrain." No analysis needed; verdict: FAIL.
2. **NaN / zero ``exp_avg_sq``** — strong secondary signal.
   Adam's second-moment estimate didn't accumulate any useful
   variance, meaning gradients didn't propagate meaningfully. Often
   co-occurs with case 1 but worth flagging independently for
   trainings that *did* run many steps with broken loss.
3. **Per-layer ratio: layer mean vs minimum layer mean of
   ``exp_avg_sq``** — heuristic. Layers whose second moment sits
   dramatically above the calmest layer's still see large gradient
   variance — they haven't converged. Verdict: WARN when any layer
   crosses the threshold, FAIL when more than ``layer_failure_frac``
   of them do.

## Why no null calibration

Other probes z-score against a null-adapter baseline ("how much
signal does random noise produce?"). Gradient state has no
meaningful null — there's no equivalent to a "random optimizer
snapshot" that the LoRA's noise floor would settle into. Verdict
thresholds are explicit heuristics; document them as "tune from user
feedback" rather than dressing them up with fake calibration math.

## Inputs

The spec carries ``adapter_path`` (the dir containing
``training_state.pt`` + ``adapter_model.safetensors``), typically
populated by the dlm autogen bridge from
``DlmHandle.adapter_path``.

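An illustrative layout (the directory name is hypothetical; the two
file names are the ones named above):

    adapters/my-run/v3/
    ├── adapter_model.safetensors
    └── training_state.pt
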
## Verdict thresholds (all configurable in spec)

- ``min_steps_threshold = 50`` — below this is severely undertrained.
- ``undertrained_layer_ratio = 2.0`` — a layer's mean ``exp_avg_sq``
  must be > 2× the **minimum** layer's mean to count as "still has
  high gradient variance." We compare against the min (not the
  global mean) because the global mean rises with the outliers
  themselves: if K of N layers spike together to some large value,
  the global mean approaches K/N of that value and each spiking
  layer's ratio caps at N/K, so once K ≥ N/ratio no layer can cross
  the threshold no matter how hard it spikes. The min-baseline gives
  a stable "K layers are anomalously high vs the calmest layer"
  signal regardless of how many layers spike; see the worked example
  below.
- ``layer_failure_frac = 0.3`` — FAIL if more than 30% of layers
  cross the per-layer threshold; any smaller non-zero count yields
  WARN.
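
## Worked example (illustrative numbers)

Take ``N = 10`` layers where ``K = 6`` spike to a mean ``exp_avg_sq``
of ``100.0`` while the other four sit at ``1.0``. The global mean is
``(6 * 100 + 4 * 1) / 10 = 60.4``, so each spiking layer's global-mean
ratio is ``100 / 60.4 ≈ 1.66``, below a ``2.0`` threshold: none of the
six would be flagged. The min-baseline ratio is ``100 / 1.0 = 100``,
which flags all six.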
59 """
60
61 from __future__ import annotations
62
63 import math
64 import statistics
65 from pathlib import Path
66 from typing import ClassVar, Literal
67
68 from pydantic import Field
69
70 from dlm_sway.core.errors import MissingTrainingStateError, SwayError
71 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
72 from dlm_sway.probes._param_id_mapping import ParamMappingError, map_param_ids_to_layers
73 from dlm_sway.probes._training_state import TrainingStateError, load_training_state
74 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
75
76
class GradientGhostSpec(ProbeSpec):
    """Spec for ``kind: gradient_ghost``."""

    kind: Literal["gradient_ghost"] = "gradient_ghost"
    adapter_path: str
    """Path to a dlm adapter version directory (must contain
    ``training_state.pt``). Resolved relative to the spec file's
    cwd via the same convention sway already uses for
    ``models.ft.adapter``."""
    min_steps_threshold: int = Field(default=50, ge=1)
    """``global_step`` below this → FAIL (severely undertrained)."""
    undertrained_layer_ratio: float = Field(default=2.0, gt=1.0)
    """A layer counts as 'high gradient variance' when its mean
    ``exp_avg_sq`` exceeds ``ratio * min_layer_mean``. The
    min-baseline (rather than global mean) is robust to outliers —
    see the module docstring for the asymptotic-cap reasoning."""
    layer_failure_frac: float = Field(default=0.3, ge=0.0, le=1.0)
    """FAIL when more than this fraction of layers cross the
    ``undertrained_layer_ratio`` threshold; any smaller non-zero
    count yields WARN."""


class GradientGhostProbe(Probe):
    """Pre-run training-health probe."""

    kind = "gradient_ghost"
    spec_cls = GradientGhostSpec
    category = "calibration"
    needs_backend: ClassVar[bool] = False

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # No backend / sections / null_stats needed.
        assert isinstance(spec, GradientGhostSpec)
        adapter_dir = Path(spec.adapter_path).expanduser()

        # === 1) Load + validate training state ===
        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )
        except (TrainingStateError, SwayError) as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
            )

        # === 2) Layer grouping (best-effort; degrades to no-grouping) ===
        try:
            grouping = map_param_ids_to_layers(adapter_dir, num_params=len(snap.per_param))
            layer_count = grouping.num_layers
            params_per_layer = grouping.params_per_layer
        except ParamMappingError:
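            # Degrade gracefully: without a param→layer map, signals 1 and 2
            # still run; the per-layer pass in section 5 then no-ops and the
            # verdict rests on global_step plus the NaN check alone.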
            grouping = None
            layer_count = 0
            params_per_layer = 0

        # === 3) Primary signal: global_step floor ===
        if snap.global_step < spec.min_steps_threshold:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=float(snap.global_step),
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "min_steps_threshold": spec.min_steps_threshold,
                    "epoch": snap.epoch,
                    "num_params": len(snap.per_param),
                    "num_layers": layer_count,
                    "best_val_loss": snap.best_val_loss
                    if math.isfinite(snap.best_val_loss)
                    else None,
                    "primary_signal": "global_step_below_threshold",
                },
                message=(
                    f"severely undertrained: global_step={snap.global_step} "
                    f"< threshold {spec.min_steps_threshold}. Probe scores on "
                    f"this adapter will be unreliable; consider retraining."
                ),
            )

        # === 4) NaN / zero exp_avg_sq detection ===
        finite_means = [
            ps.exp_avg_sq_mean for ps in snap.per_param if math.isfinite(ps.exp_avg_sq_mean)
        ]
        nan_or_inf_count = len(snap.per_param) - len(finite_means)
        if not finite_means:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=0.0,
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "num_params": len(snap.per_param),
                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                    "primary_signal": "all_optimizer_state_nan",
                },
                message=(
                    "every per-param exp_avg_sq is non-finite (NaN or inf) — "
                    "training didn't propagate gradients meaningfully."
                ),
            )

        # === 5) Per-layer secondary signal ===
        global_mean = statistics.fmean(finite_means)
        per_layer_means: dict[int, float] = {}
        per_layer_undertrained: list[int] = []
        baseline_min: float = 0.0

        if grouping is not None and global_mean > 0.0:
            # Group finite per-param means by layer index.
            buckets: dict[int, list[float]] = {idx: [] for idx in grouping.layer_indices}
            for ps in snap.per_param:
                if not math.isfinite(ps.exp_avg_sq_mean):
                    continue
                layer = grouping.layer_of.get(ps.param_id)
                if layer is not None:
                    buckets[layer].append(ps.exp_avg_sq_mean)
            for layer_idx, vals in buckets.items():
                if not vals:
                    continue
                per_layer_means[layer_idx] = statistics.fmean(vals)
            # Min-baseline ratio (see module docstring on why we use
            # min instead of global mean — the global-mean ratio is
            # asymptotically capped and can't catch the case where K
            # layers all spike together).
            if per_layer_means:
                baseline_min = min(per_layer_means.values())
                if baseline_min > 0.0:
                    for layer_idx, mean in per_layer_means.items():
                        if mean > spec.undertrained_layer_ratio * baseline_min:
                            per_layer_undertrained.append(layer_idx)

        frac_undertrained = len(per_layer_undertrained) / layer_count if layer_count > 0 else 0.0

        # Top-3 worst layers (highest ratio vs baseline_min) — useful
        # evidence even when no layer crosses the threshold.
        ranked_layers = sorted(per_layer_means.items(), key=lambda kv: -kv[1])[:3]
        worst_layers = [
            {"layer": idx, "ratio": (mean / baseline_min) if baseline_min > 0 else None}
            for idx, mean in ranked_layers
        ]

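        # Score bands (heuristic): per-layer FAIL → 0.3, WARN → 0.7,
        # PASS → 0.9; the hard failures in sections 3-4 above score 0.0.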
        # === 6) Final verdict ===
        if frac_undertrained > spec.layer_failure_frac:
            verdict = Verdict.FAIL
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) still show high gradient variance — training-loss "
                f"curve likely hasn't flattened. global_step={snap.global_step}."
            )
            score = 0.3  # Bottom of "partial" band.
        elif per_layer_undertrained:
            verdict = Verdict.WARN
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) above {spec.undertrained_layer_ratio:.1f}× the minimum "
                f"layer's exp_avg_sq mean — adapter may be partially undertrained but "
                f"other probes can still produce signal. global_step={snap.global_step}."
            )
            score = 0.7
        else:
            verdict = Verdict.PASS
            message = (
                f"global_step={snap.global_step}, no layer above "
                f"{spec.undertrained_layer_ratio:.1f}× the minimum layer's "
                f"exp_avg_sq mean — training looks converged."
            )
            score = 0.9

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=frac_undertrained,
            z_score=None,
            evidence={
                "global_step": snap.global_step,
                "epoch": snap.epoch,
                "num_params": len(snap.per_param),
                "num_layers": layer_count,
                "params_per_layer": params_per_layer,
                "global_mean_exp_avg_sq": global_mean,
                "frac_layers_undertrained": frac_undertrained,
                "num_layers_undertrained": len(per_layer_undertrained),
                "worst_layers": worst_layers,
                "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                "best_val_loss": snap.best_val_loss if math.isfinite(snap.best_val_loss) else None,
                "use_qlora": snap.use_qlora,
            },
            message=message,
        )
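

# ---------------------------------------------------------------------------
# Minimal standalone sketch (illustrative, not part of the sway API). The
# probe never touches its RunContext (``run()`` deletes it on entry), so
# passing ``None`` works at runtime even though the annotation says
# ``RunContext``. It assumes ``ProbeSpec`` carries a ``name`` field
# (``run()`` reads ``spec.name``) and that ``Probe`` subclasses take no
# constructor args; the adapter path is hypothetical.
#
#     spec = GradientGhostSpec(
#         name="gradient_ghost_preflight",
#         adapter_path="adapters/my-run/v3",
#     )
#     result = GradientGhostProbe().run(spec, ctx=None)  # type: ignore[arg-type]
#     print(result.verdict, result.message)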