@@ -0,0 +1,268 @@ |
| 1 | +"""P3 / S25 — GradientGhostProbe (pre-run, cross-repo, no model load). |
| 2 | + |
| 3 | +Reads the optimizer-state snapshot dlm writes alongside every adapter |
| 4 | +version (``training_state.pt``) and answers one practical question |
| 5 | +**before any forward pass fires**: was this adapter trained long enough |
| 6 | +to converge? |
| 7 | + |
| 8 | +Loads in ~50 ms (a 1.5B-param adapter's optimizer state is ~50 MB |
| 9 | +pickle). The full adherence suite costs 30+ s; running this first as |
| 10 | +a pre-flight check lets ``sway check`` short-circuit on obviously- |
| 11 | +broken adapters without paying the model-load tax. |
| 12 | + |
| 13 | +## Signal ladder (most → least decisive) |
| 14 | + |
| 15 | +1. **``global_step < min_steps_threshold``** — primary signal. |
| 16 | + Catches the 90%-case "user did ``--max-steps 5`` for a smoke test |
| 17 | + and forgot to retrain." No analysis needed; verdict: FAIL. |
| 18 | +2. **NaN / zero ``exp_avg_sq``** — strong secondary signal. |
| 19 | + Adam's second-moment estimate didn't accumulate any useful |
| 20 | + variance, meaning gradients didn't propagate meaningfully. Often |
| 21 | + co-occurs with case 1 but worth flagging independently for |
| 22 | + trainings that *did* run many steps with broken loss. |
| 23 | +3. **Per-layer ratio: layer mean vs global mean of |
| 24 | + ``exp_avg_sq``** — heuristic. Layers whose second-moment is |
| 25 | + dramatically above the global mean still see large gradient |
   variance — they haven't converged. Verdict: WARN when any
   layer crosses the threshold; FAIL when a significant fraction do.
| 28 | + |
| 29 | +## Why no null calibration |
| 30 | + |
| 31 | +Other probes z-score against a null-adapter baseline ("how much |
| 32 | +signal does random noise produce?"). Gradient state has no |
| 33 | +meaningful null — there's no equivalent to a "random optimizer |
| 34 | +snapshot" that the LoRA's noise floor would settle into. Verdict |
| 35 | +thresholds are explicit heuristics; document as "tune from user |
| 36 | +feedback" rather than fake calibration math. |
| 37 | + |
| 38 | +## Inputs |
| 39 | + |
| 40 | +The spec carries ``adapter_path`` (the dir containing |
| 41 | +``training_state.pt`` + ``adapter_model.safetensors``). Typically |
| 42 | +populated by the dlm autogen bridge from |
| 43 | +``DlmHandle.adapter_path``. |
| 44 | + |
| 45 | +## Verdict thresholds (all configurable in spec) |
| 46 | + |
| 47 | +- ``min_steps_threshold = 50`` — below this is severely undertrained. |
| 48 | +- ``undertrained_layer_ratio = 2.0`` — a layer's mean ``exp_avg_sq`` |
| 49 | + must be > 2× the global mean to count as "still has high gradient |
| 50 | + variance." Multiplicative threshold (relative, not absolute) so |
| 51 | + the probe is architecture-agnostic. |
- ``layer_failure_frac = 0.3`` — FAIL if more than 30% of layers
  cross the per-layer threshold; any smaller nonzero count → WARN.
| 54 | +""" |
| 55 | + |
| 56 | +from __future__ import annotations |
| 57 | + |
| 58 | +import math |
| 59 | +import statistics |
| 60 | +from pathlib import Path |
| 61 | +from typing import ClassVar, Literal |
| 62 | + |
| 63 | +from pydantic import Field |
| 64 | + |
| 65 | +from dlm_sway.core.errors import MissingTrainingStateError, SwayError |
| 66 | +from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 67 | +from dlm_sway.probes._param_id_mapping import ParamMappingError, map_param_ids_to_layers |
| 68 | +from dlm_sway.probes._training_state import TrainingStateError, load_training_state |
| 69 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 70 | + |
| 71 | + |
| 72 | +class GradientGhostSpec(ProbeSpec): |
| 73 | + """Spec for ``kind: gradient_ghost``.""" |
| 74 | + |
| 75 | + kind: Literal["gradient_ghost"] = "gradient_ghost" |
| 76 | + adapter_path: str |
| 77 | + """Path to a dlm adapter version directory (must contain |
| 78 | + ``training_state.pt``). Resolved relative to the spec file's |
| 79 | + cwd via the same convention sway already uses for |
| 80 | + ``models.ft.adapter``.""" |
| 81 | + min_steps_threshold: int = Field(default=50, ge=1) |
| 82 | + """``global_step`` below this → FAIL (severely undertrained).""" |
| 83 | + undertrained_layer_ratio: float = Field(default=2.0, gt=1.0) |
| 84 | + """A layer counts as 'high gradient variance' when its mean |
| 85 | + ``exp_avg_sq`` exceeds ``ratio * global_mean``. Strictly > 1 |
| 86 | + (a value of 1 would always flag half the layers).""" |
| 87 | + layer_failure_frac: float = Field(default=0.3, ge=0.0, le=1.0) |
| 88 | + """WARN when more than this fraction of layers cross the |
| 89 | + ``undertrained_layer_ratio`` threshold.""" |
| 90 | + |
| 91 | + |
class GradientGhostProbe(Probe):
    """Pre-run training-health probe.

    Grades an adapter's training convergence from its serialized optimizer
    state alone — no model load, no forward pass. Signal ladder, most to
    least decisive:

    1. ``global_step`` below ``min_steps_threshold`` → FAIL.
    2. ``exp_avg_sq`` entirely non-finite *or* entirely zero → FAIL
       (Adam's second moment never accumulated useful variance).
    3. Per-layer mean ``exp_avg_sq`` vs the global mean → WARN when any
       layer exceeds ``undertrained_layer_ratio`` × global mean, FAIL when
       more than ``layer_failure_frac`` of layers do.
    """

    kind = "gradient_ghost"
    spec_cls = GradientGhostSpec
    category = "calibration"
    # Operates purely on the on-disk snapshot; never needs an inference backend.
    needs_backend: ClassVar[bool] = False

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Evaluate training health for ``spec.adapter_path``.

        Returns SKIP when ``training_state.pt`` is absent (environment gap,
        not evidence about the adapter), ERROR when it exists but cannot be
        parsed, and FAIL/WARN/PASS per the signal ladder otherwise.
        """
        del ctx  # No backend / sections / null_stats needed.
        assert isinstance(spec, GradientGhostSpec)
        adapter_dir = Path(spec.adapter_path).expanduser()

        # === 1) Load + validate training state ===
        try:
            snap = load_training_state(adapter_dir)
        except MissingTrainingStateError as exc:
            # A missing snapshot is not an adapter defect — SKIP so the
            # rest of the suite still runs.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )
        except (TrainingStateError, SwayError) as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
            )

        # === 2) Layer grouping (best-effort; degrades to no-grouping) ===
        try:
            grouping = map_param_ids_to_layers(adapter_dir, num_params=len(snap.per_param))
            layer_count = grouping.num_layers
            params_per_layer = grouping.params_per_layer
        except ParamMappingError:
            # Per-layer analysis becomes unavailable; the global signals
            # (steps floor, NaN/zero second moment) still apply.
            grouping = None
            layer_count = 0
            params_per_layer = 0

        # === 3) Primary signal: global_step floor ===
        if snap.global_step < spec.min_steps_threshold:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=float(snap.global_step),
                z_score=None,  # No null calibration exists for optimizer state.
                evidence={
                    "global_step": snap.global_step,
                    "min_steps_threshold": spec.min_steps_threshold,
                    "epoch": snap.epoch,
                    "num_params": len(snap.per_param),
                    "num_layers": layer_count,
                    # Non-finite sentinel (e.g. inf "no val yet") serializes as null.
                    "best_val_loss": snap.best_val_loss
                    if math.isfinite(snap.best_val_loss)
                    else None,
                    "primary_signal": "global_step_below_threshold",
                },
                message=(
                    f"severely undertrained: global_step={snap.global_step} "
                    f"< threshold {spec.min_steps_threshold}. Probe scores on "
                    f"this adapter will be unreliable; consider retraining."
                ),
            )

        # === 4) NaN / zero exp_avg_sq detection ===
        finite_means = [
            ps.exp_avg_sq_mean for ps in snap.per_param if math.isfinite(ps.exp_avg_sq_mean)
        ]
        nan_or_inf_count = len(snap.per_param) - len(finite_means)
        if not finite_means:
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=0.0,
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "num_params": len(snap.per_param),
                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                    "primary_signal": "all_optimizer_state_nan",
                },
                message=(
                    "every per-param exp_avg_sq is NaN or non-finite — "
                    "training didn't propagate gradients meaningfully."
                ),
            )

        global_mean = statistics.fmean(finite_means)
        if global_mean == 0.0:
            # Signal #2 covers "NaN / zero exp_avg_sq": a second moment that
            # is finite but identically zero also means no gradient variance
            # ever accumulated, so treat it like the all-NaN case instead of
            # letting it fall through to a spurious PASS.
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.FAIL,
                score=0.0,
                raw=0.0,
                z_score=None,
                evidence={
                    "global_step": snap.global_step,
                    "num_params": len(snap.per_param),
                    "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                    "primary_signal": "all_optimizer_state_zero",
                },
                message=(
                    "every finite per-param exp_avg_sq is zero — "
                    "training didn't propagate gradients meaningfully."
                ),
            )

        # === 5) Per-layer secondary signal ===
        per_layer_means: dict[int, float] = {}
        per_layer_undertrained: list[int] = []

        if grouping is not None:
            # Group finite per-param means by layer index. global_mean > 0
            # is guaranteed here (the zero case returned above).
            buckets: dict[int, list[float]] = {idx: [] for idx in grouping.layer_indices}
            for ps in snap.per_param:
                if not math.isfinite(ps.exp_avg_sq_mean):
                    continue
                layer = grouping.layer_of.get(ps.param_id)
                if layer is not None:
                    buckets[layer].append(ps.exp_avg_sq_mean)
            for layer_idx, vals in buckets.items():
                if not vals:
                    continue
                layer_mean = statistics.fmean(vals)
                per_layer_means[layer_idx] = layer_mean
                if layer_mean > spec.undertrained_layer_ratio * global_mean:
                    per_layer_undertrained.append(layer_idx)

        frac_undertrained = len(per_layer_undertrained) / layer_count if layer_count > 0 else 0.0

        # Top-3 worst layers (highest ratio) — useful evidence even
        # when no layer crosses the threshold.
        ranked_layers = sorted(per_layer_means.items(), key=lambda kv: -kv[1])[:3]
        worst_layers = [
            {"layer": idx, "ratio": mean / global_mean} for idx, mean in ranked_layers
        ]

        # === 6) Final verdict ===
        if frac_undertrained > spec.layer_failure_frac:
            verdict = Verdict.FAIL
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) still show high gradient variance — training-loss "
                f"curve likely hasn't flattened. global_step={snap.global_step}."
            )
            score = 0.3  # Bottom of "partial" band.
        elif per_layer_undertrained:
            verdict = Verdict.WARN
            message = (
                f"{frac_undertrained:.0%} of layers ({len(per_layer_undertrained)}/"
                f"{layer_count}) above {spec.undertrained_layer_ratio:.1f}× the global "
                f"exp_avg_sq mean — adapter may be partially undertrained but other "
                f"probes can still produce signal. global_step={snap.global_step}."
            )
            score = 0.7
        else:
            verdict = Verdict.PASS
            if grouping is None:
                # Don't claim per-layer convergence we never checked.
                message = (
                    f"global_step={snap.global_step}; per-layer grouping "
                    f"unavailable, verdict rests on global signals only."
                )
            else:
                message = (
                    f"global_step={snap.global_step}, no layer above "
                    f"{spec.undertrained_layer_ratio:.1f}× global exp_avg_sq mean — "
                    f"training looks converged."
                )
            score = 0.9

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=frac_undertrained,
            z_score=None,
            evidence={
                "global_step": snap.global_step,
                "epoch": snap.epoch,
                "num_params": len(snap.per_param),
                "num_layers": layer_count,
                "params_per_layer": params_per_layer,
                "global_mean_exp_avg_sq": global_mean,
                "frac_layers_undertrained": frac_undertrained,
                "num_layers_undertrained": len(per_layer_undertrained),
                "worst_layers": worst_layers,
                "num_nonfinite_exp_avg_sq": nan_or_inf_count,
                "layer_grouping_available": grouping is not None,
                "best_val_loss": snap.best_val_loss if math.isfinite(snap.best_val_loss) else None,
                "use_qlora": snap.use_qlora,
            },
            message=message,
        )