"""P3 / S25 — GradientGhostProbe (pre-run, cross-repo, no model load).

Reads the optimizer-state snapshot dlm writes alongside every adapter
version (``training_state.pt``) and answers one practical question
**before any forward pass fires**: was this adapter trained long enough
to converge?

Loads in ~50 ms (a 1.5B-param adapter's optimizer state is a ~50 MB
pickle). The full adherence suite costs 30+ s; running this first as
a pre-flight check lets ``sway check`` short-circuit on obviously
broken adapters without paying the model-load tax.

## Signal ladder (most → least decisive)

1. **``global_step < min_steps_threshold``** — primary signal.
   Catches the 90%-case "user did ``--max-steps 5`` for a smoke test
   and forgot to retrain." No analysis needed; verdict: FAIL.
2. **NaN / zero ``exp_avg_sq``** — strong secondary signal.
   Adam's second-moment estimate didn't accumulate any useful
   variance, meaning gradients didn't propagate meaningfully (a
   short sketch of the update follows this list). Often co-occurs
   with case 1, but it is worth flagging independently for trainings
   that *did* run many steps with a broken loss.
3. **Per-layer ratio: layer mean vs the minimum layer mean of
   ``exp_avg_sq``** — heuristic. Layers whose second-moment mean is
   dramatically above the calmest layer's still see large gradient
   variance — they haven't converged. Verdict: WARN when any layer
   crosses the threshold, FAIL when more than ``layer_failure_frac``
   of layers do.
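
For reference, ``exp_avg_sq`` is Adam's running average of squared
gradients. A minimal sketch of one update step (standard Adam, nothing
dlm-specific):

    def adam_second_moment(v_prev: float, grad: float, beta2: float = 0.999) -> float:
        # One step of Adam's exp_avg_sq update: a running average of grad**2.
        # All-zero gradients keep it at zero; a single NaN gradient poisons it
        # for every later step -- both are exactly what signal 2 keys on.
        return beta2 * v_prev + (1.0 - beta2) * grad**2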

## Why no null calibration

Other probes z-score against a null-adapter baseline ("how much
signal does random noise produce?"). Gradient state has no
meaningful null — there's no equivalent to a "random optimizer
snapshot" that the LoRA's noise floor would settle into. Verdict
thresholds are therefore explicit heuristics, documented as "tune
from user feedback" rather than dressed up as calibration math.

## Inputs

The spec carries ``adapter_path`` (the dir containing
``training_state.pt`` + ``adapter_model.safetensors``). Typically
populated by the dlm autogen bridge from
``DlmHandle.adapter_path``.
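
For illustration, the spec can also be constructed directly in Python.
A minimal sketch; the ``name`` value and path are placeholders, and it
assumes ``name`` is the only required field inherited from
``ProbeSpec``:

    spec = GradientGhostSpec(
        name="gradient_ghost_preflight",  # placeholder probe name
        adapter_path="runs/adapter/v3",   # placeholder; must hold training_state.pt
    )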

## Verdict thresholds (all configurable in spec)

- ``min_steps_threshold = 50`` — below this is severely undertrained.
- ``undertrained_layer_ratio = 2.0`` — a layer's mean ``exp_avg_sq``
  must be > 2× the **minimum** layer's mean to count as "still has
  high gradient variance." We compare against the min (not the
  global mean) because spiking layers drag the global mean up with
  them: with K of N layers spiking, a spiking layer's ratio to the
  global mean tops out near ``N/K``, so once enough layers spike
  together none of them can clear the threshold. Min-baseline gives
  a stable "K layers are anomalously high vs the calmest layer"
  signal regardless of how many layers spike (see the worked sketch
  after this list).
- ``layer_failure_frac = 0.3`` — FAIL if more than 30% of layers
  cross the per-layer threshold; any smaller non-zero fraction is
  a WARN.
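
A worked sketch of the per-layer check with made-up numbers (the layer
means below are hypothetical; real values come from
``training_state.pt``):

    per_layer_means = {0: 1e-9, 1: 2.5e-9, 2: 8e-9, 3: 9e-9}  # hypothetical
    baseline_min = min(per_layer_means.values())               # 1e-9
    flagged = [i for i, m in per_layer_means.items()
               if m > 2.0 * baseline_min]                      # layers 1, 2, 3
    # 3 of 4 layers (75%) cross the 2.0x threshold; 0.75 > 0.3
    # (layer_failure_frac), so the verdict would be FAIL.
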
"""

from __future__ import annotations

import math
import statistics
from pathlib import Path
from typing import ClassVar, Literal

from pydantic import Field

from dlm_sway.core.errors import MissingTrainingStateError, SwayError
from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
from dlm_sway.probes._param_id_mapping import ParamMappingError, map_param_ids_to_layers
from dlm_sway.probes._training_state import TrainingStateError, load_training_state
from dlm_sway.probes.base import Probe, ProbeSpec, RunContext


class GradientGhostSpec(ProbeSpec):
    """Spec for ``kind: gradient_ghost``."""

    kind: Literal["gradient_ghost"] = "gradient_ghost"
    adapter_path: str
    """Path to a dlm adapter version directory (must contain
    ``training_state.pt``). Resolved relative to the spec file's
    cwd via the same convention sway already uses for
    ``models.ft.adapter``."""
    min_steps_threshold: int = Field(default=50, ge=1)
    """``global_step`` below this → FAIL (severely undertrained)."""
    undertrained_layer_ratio: float = Field(default=2.0, gt=1.0)
    """A layer counts as 'high gradient variance' when its mean
    ``exp_avg_sq`` exceeds ``ratio * min_layer_mean``. The min-
    baseline (rather than global mean) is robust to outliers — see
    the module docstring for the asymptotic-cap reasoning."""
    layer_failure_frac: float = Field(default=0.3, ge=0.0, le=1.0)
    """FAIL when more than this fraction of layers cross the
    ``undertrained_layer_ratio`` threshold; any smaller non-zero
    fraction is a WARN."""


class GradientGhostProbe(Probe):
    """Pre-run training-health probe."""

    kind = "gradient_ghost"
    spec_cls = GradientGhostSpec
    category = "calibration"
    needs_backend: ClassVar[bool] = False

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # No backend / sections / null_stats needed.
        assert isinstance(spec, GradientGhostSpec)
        adapter_dir = Path(spec.adapter_path).expanduser()

        # === 1) Load + validate training state ===
        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )
        except (TrainingStateError, SwayError) as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
            )

        # === 2) Layer grouping (best-effort; degrades to no-grouping) ===
        try:
            grouping = map_param_ids_to_layers(adapter_dir, num_params=len(snap.per_param))
            layer_count = grouping.num_layers
            params_per_layer = grouping.params_per_layer
        except ParamMappingError:
            grouping = None
            layer_count = 0
            params_per_layer = 0

        # === 3) Primary signal: global_step floor ===
        if snap.global_step < spec.min_steps_threshold:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=float(snap.global_step),
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "min_steps_threshold": spec.min_steps_threshold,
                    "epoch": snap.epoch,
                    "num_params": len(snap.per_param),
                    "num_layers": layer_count,
                    "best_val_loss": snap.best_val_loss
                    if math.isfinite(snap.best_val_loss)
                    else None,
                    "primary_signal": "global_step_below_threshold",
                },
                message=(
                    f"severely undertrained: global_step={snap.global_step} "
                    f"< threshold {spec.min_steps_threshold}. Probe scores on "
                    f"this adapter will be unreliable; consider retraining."
                ),
            )

        # === 4) NaN / zero exp_avg_sq detection ===
        finite_means = [
            ps.exp_avg_sq_mean for ps in snap.per_param if math.isfinite(ps.exp_avg_sq_mean)
        ]
        nan_or_inf_count = len(snap.per_param) - len(finite_means)
        if not finite_means:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=0.0,
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "num_params": len(snap.per_param),
                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                    "primary_signal": "all_optimizer_state_nan",
                },
                message=(
                    "every per-param exp_avg_sq is NaN or non-finite — "
                    "training didn't propagate gradients meaningfully."
                ),
            )

        # === 5) Per-layer secondary signal ===
        global_mean = statistics.fmean(finite_means)
        per_layer_means: dict[int, float] = {}
        per_layer_undertrained: list[int] = []
        baseline_min: float = 0.0

        if grouping is not None and global_mean > 0.0:
            # Group finite per-param means by layer index.
            buckets: dict[int, list[float]] = {idx: [] for idx in grouping.layer_indices}
            for ps in snap.per_param:
                if not math.isfinite(ps.exp_avg_sq_mean):
                    continue
                layer = grouping.layer_of.get(ps.param_id)
                if layer is not None:
                    buckets[layer].append(ps.exp_avg_sq_mean)
            for layer_idx, vals in buckets.items():
                if not vals:
                    continue
                per_layer_means[layer_idx] = statistics.fmean(vals)
            # Min-baseline ratio (see module docstring on why we use
            # min instead of global mean — global-mean ratio is
            # asymptotically capped and can't catch the case where K
            # layers all spike together).
            if per_layer_means:
                baseline_min = min(per_layer_means.values())
                if baseline_min > 0.0:
                    for layer_idx, mean in per_layer_means.items():
                        if mean > spec.undertrained_layer_ratio * baseline_min:
                            per_layer_undertrained.append(layer_idx)

        frac_undertrained = len(per_layer_undertrained) / layer_count if layer_count > 0 else 0.0

        # Top-3 worst layers (highest ratio vs baseline_min) — useful
        # evidence even when no layer crosses the threshold.
        ranked_layers = sorted(per_layer_means.items(), key=lambda kv: -kv[1])[:3]
        worst_layers = [
            {"layer": idx, "ratio": (mean / baseline_min) if baseline_min > 0 else None}
            for idx, mean in ranked_layers
        ]

        # === 6) Final verdict ===
        if frac_undertrained > spec.layer_failure_frac:
            verdict = Verdict.FAIL
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) still show high gradient variance — training-loss "
                f"curve likely hasn't flattened. global_step={snap.global_step}."
            )
            score = 0.3  # Bottom of "partial" band.
        elif per_layer_undertrained:
            verdict = Verdict.WARN
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) above {spec.undertrained_layer_ratio:.1f}× the minimum "
                f"layer's exp_avg_sq mean — adapter may be partially undertrained but "
                f"other probes can still produce signal. global_step={snap.global_step}."
            )
            score = 0.7
        else:
            verdict = Verdict.PASS
            message = (
                f"global_step={snap.global_step}, no layer above "
                f"{spec.undertrained_layer_ratio:.1f}× the minimum layer's exp_avg_sq "
                f"mean — training looks converged."
            )
            score = 0.9

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=frac_undertrained,
            z_score=None,
            evidence={
                "global_step": snap.global_step,
                "epoch": snap.epoch,
                "num_params": len(snap.per_param),
                "num_layers": layer_count,
                "params_per_layer": params_per_layer,
                "global_mean_exp_avg_sq": global_mean,
                "frac_layers_undertrained": frac_undertrained,
                "num_layers_undertrained": len(per_layer_undertrained),
                "worst_layers": worst_layers,
                "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                "best_val_loss": snap.best_val_loss if math.isfinite(snap.best_val_loss) else None,
                "use_qlora": snap.use_qlora,
            },
            message=message,
        )