@@ -0,0 +1,486 @@ |
| | 1 | +"""Training-drift — pre-run probe that reads dlm's per-step loss curve. |
| | 2 | + |
| | 3 | +Sister probe to :mod:`gradient_ghost`. Where ``gradient_ghost`` reads |
| | 4 | +the optimizer state at the end of training, ``training_drift`` reads |
| | 5 | +the loss curve *during* training: the rich signal about *how* the |
| | 6 | +adapter trained, not just what it produced. |
| | 7 | + |
| | 8 | +Four metrics extracted from the curve: |
| | 9 | + |
| | 10 | +- ``final_loss`` — last recorded step's loss. |
| | 11 | +- ``convergence_ratio`` — ``final_loss / initial_loss``. Lower is |
| | 12 | + better; a healthy adapter cuts loss by half or more. |
| | 13 | +- ``smoothness`` — ``1 - (var(Δloss) / var(loss))``. Values near 1 |
| | 14 | + mean the curve descends smoothly; near 0 means each step's |
| | 15 | + change-in-loss is comparable to the loss range itself (spiky). |
| | 16 | +- ``instability_events`` — count of steps where |
| | 17 | + ``|Δloss| > 3 · rolling_std``. Spikes that survive the rolling |
| | 18 | + window are real — they correlate with silent adapter degradation. |
| | 19 | + |
| | 20 | +Verdict: PASS when all three of (smoothness ≥ 0.7, instability_events |
| | 21 | +== 0, convergence_ratio ≤ 0.7); else WARN. The thresholds are |
| | 22 | +hand-tuned defaults; spec fields override them. |
| | 23 | + |
| | 24 | +## Why no null calibration |
| | 25 | + |
| | 26 | +Mirrors ``prompt_collapse`` and ``multi_turn_coherence_decay``: a null |
| | 27 | +adapter doesn't *train* — it has no loss curve. The null distribution |
| | 28 | +of "smoothness on a noise adapter" is undefined. Fixed-threshold |
| | 29 | +verdicts are the published path; users override per-spec. |
| | 30 | + |
| | 31 | +## Log-format note |
| | 32 | + |
| | 33 | +dlm writes one JSONL per run at |
| | 34 | +``<store_path>/logs/train-NNNNNN-YYYYMMDDTHHMMSS.jsonl``. Each line is |
| | 35 | +``{"type": "<banner|step|run_complete|...>", ...}``. This probe |
| | 36 | +filters for ``type == "step"`` records and reads ``step`` + ``loss``. |
| | 37 | +The sibling ``*.summary.json`` has run aggregates; we don't consume |
| | 38 | +it here — the curve is richer. |
| | 39 | +""" |
| | 40 | + |
| | 41 | +from __future__ import annotations |
| | 42 | + |
| | 43 | +import json |
| | 44 | +import math |
| | 45 | +from pathlib import Path |
| | 46 | +from typing import ClassVar, Literal |
| | 47 | + |
| | 48 | +import numpy as np |
| | 49 | +from pydantic import Field |
| | 50 | + |
| | 51 | +from dlm_sway.core.errors import SwayError |
| | 52 | +from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| | 53 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| | 54 | + |
| | 55 | + |
class TrainingDriftError(SwayError):
    """Signals a training log that exists but cannot be parsed.

    Deliberately separate from the probe's SKIP outcomes: an absent
    ``logs/`` directory or unset store means there is nothing to
    analyze (no .dlm context, or an adapter that never trained), so
    the probe skips quietly. A present-but-corrupt JSONL is
    different — the user *has* training data and we failed to read
    it, which deserves a loud ERROR rather than a silent skip.
    """
| | 65 | + |
| | 66 | + |
class TrainingDriftSpec(ProbeSpec):
    """Spec for ``kind: training_drift``.

    Thresholds below are the hand-tuned defaults described in the
    module docstring; each can be overridden per-spec. The probe
    only ever SKIPs (unset ``store_path``, missing logs, or a curve
    shorter than ``min_steps``), WARNs, PASSes, or ERRORs — it never
    emits FAIL.
    """

    kind: Literal["training_drift"] = "training_drift"
    store_path: str | None = None
    """Path to a dlm store root (the directory containing
    ``logs/`` and ``adapter/``). When ``None`` the probe SKIPs —
    typically populated by the .dlm autogen path which knows the
    store from the dlm_id resolution."""
    min_steps: int = Field(default=10, ge=2)
    """Skip when the curve has fewer steps than this. Short runs
    (3-step early-stops, smoke tests) produce undefined fits;
    surfacing those as PASS/FAIL would be misleading."""
    rolling_window: int = Field(default=10, ge=2)
    """Window for the rolling std used in spike detection. Larger
    windows mean tighter spike detection at the cost of missing
    rapid oscillations. 10 is a balance for typical 100–1000 step
    runs."""
    spike_sigma: float = Field(default=3.0, gt=0.0)
    """A step is an instability event when ``|Δloss|`` exceeds this
    many rolling standard deviations. 3σ is the conventional
    "outlier" boundary; lower for more sensitivity, higher for less
    noise tolerance."""
    assert_smoothness_gte: float = Field(default=0.7, ge=0.0, le=1.0)
    """Minimum smoothness for PASS."""
    assert_convergence_ratio_lte: float = Field(default=0.7, gt=0.0)
    """Maximum ``final_loss / initial_loss`` for PASS. A "well-
    trained" run typically halves loss; permissive default tolerates
    noisier data."""
    assert_instability_events_lte: int = Field(default=0, ge=0)
    """Maximum allowed spike count. Default 0 — any spike → WARN."""
| | 98 | + |
| | 99 | + |
class TrainingDriftProbe(Probe):
    """The "did this adapter train smoothly?" pre-run probe."""

    kind = "training_drift"
    spec_cls = TrainingDriftSpec
    category = "calibration"
    needs_backend: ClassVar[bool] = False
    # No model load, no backend: the probe only reads JSONL logs and
    # finishes in <100ms, making it a natural ``sway check``
    # first-pass before any heavy probe fires. Same posture as
    # gradient_ghost.

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # Backend / sections / null_stats are all unused here.
        assert isinstance(spec, TrainingDriftSpec)

        def skip(message, evidence=None):
            # Every early-out below shares this SKIP shape; the
            # evidence field is attached only when supplied.
            kwargs = {
                "name": spec.name,
                "kind": spec.kind,
                "verdict": Verdict.SKIP,
                "score": None,
                "message": message,
            }
            if evidence is not None:
                kwargs["evidence"] = evidence
            return ProbeResult(**kwargs)

        if spec.store_path is None:
            return skip("no store_path provided (no .dlm context)")

        store_path = Path(spec.store_path).expanduser().resolve()
        logs_dir = store_path / "logs"
        if not logs_dir.is_dir():
            return skip(f"no logs/ directory under {store_path}")

        log_paths = sorted(logs_dir.glob("train-*.jsonl"))
        if not log_paths:
            return skip(f"no train-*.jsonl files under {logs_dir}")

        # Merge step records across all runs. The sorted glob is
        # chronological (timestamp suffix in the filename), and
        # resumed runs that repeat a step number dedupe with the
        # latest run's value winning.
        try:
            steps_by_idx = _collect_steps(log_paths)
        except TrainingDriftError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
                evidence={"log_paths": [str(p) for p in log_paths]},
            )

        if len(steps_by_idx) < spec.min_steps:
            return skip(
                f"only {len(steps_by_idx)} step records "
                f"(< min_steps={spec.min_steps}); "
                f"curve too short to fit reliably",
                evidence={"num_steps": len(steps_by_idx)},
            )

        # Rebuild the curve as parallel step/loss arrays in step order.
        pairs = sorted(steps_by_idx.items())
        steps = np.asarray([step for step, _ in pairs], dtype=np.int64)
        losses = np.asarray([loss for _, loss in pairs], dtype=np.float64)

        # The metric helpers tolerate degenerate curves (all-NaN or
        # constant losses) without raising; the verdict mapping turns
        # those into WARNs.
        metrics = _compute_metrics(
            losses, rolling_window=spec.rolling_window, spike_sigma=spec.spike_sigma
        )
        verdict, score, message = _verdict_from_metrics(metrics, spec)

        # Keep the evidence curve bounded — a 10k-step run shouldn't
        # balloon the JSON report.
        curve = _downsampled_curve(steps, losses, cap=512)

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=metrics["smoothness"],
            evidence={
                "final_loss": metrics["final_loss"],
                "initial_loss": metrics["initial_loss"],
                "convergence_ratio": metrics["convergence_ratio"],
                "smoothness": metrics["smoothness"],
                "instability_events": metrics["instability_events"],
                "num_steps": len(losses),
                "num_log_files": len(log_paths),
                "curve_sampled": curve,
                "thresholds": {
                    "smoothness_gte": spec.assert_smoothness_gte,
                    "convergence_ratio_lte": spec.assert_convergence_ratio_lte,
                    "instability_events_lte": spec.assert_instability_events_lte,
                },
                "weight": spec.weight,
            },
            message=message,
            critical_fields=(),
            # ``raw`` (smoothness) stays within [0, 1] for ordinary
            # curves; enabling the critical-field guard would null
            # the score on degenerate inputs, which is more confusing
            # than helpful here.
        )
| | 224 | + |
| | 225 | + |
| | 226 | +# --------------------------------------------------------------------------- |
| | 227 | +# JSONL parsing |
| | 228 | +# --------------------------------------------------------------------------- |
| | 229 | + |
| | 230 | + |
def _collect_steps(log_paths: list[Path]) -> dict[int, float]:
    """Parse every JSONL, dedupe by step number, return ``{step: loss}``.

    Resumed runs append to a fresh JSONL but share step numbers with
    the prior run. Because ``log_paths`` arrives in chronological
    order, keeping the last occurrence means the most recent run wins
    for a given step — matching dlm's own ``dlm metrics`` semantics.

    Raises :class:`TrainingDriftError` when a file cannot be read or
    its first line is not valid JSON.
    """
    merged: dict[int, float] = {}
    for path in log_paths:
        try:
            lines = path.read_text(encoding="utf-8").splitlines()
        except OSError as exc:
            raise TrainingDriftError(f"failed to read {path}: {exc}") from exc
        for line_no, line in enumerate(lines, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                # A corrupt *first* line means the whole file is
                # suspect — fail loudly. Later broken lines are
                # typically a partial trailing write from a crashed
                # trainer, so they are skipped instead.
                if line_no == 1:
                    raise TrainingDriftError(
                        f"first line of {path.name} is not valid JSON: {exc}"
                    ) from exc
                continue
            if not isinstance(record, dict) or record.get("type") != "step":
                continue
            try:
                step_idx = int(record["step"])
                loss_val = float(record["loss"])
            except (KeyError, TypeError, ValueError):
                continue
            if not math.isfinite(loss_val):
                # A NaN/inf loss is itself a signal — store +inf so
                # the spike detector registers the instability while
                # downstream stats stay well-defined.
                loss_val = math.inf
            merged[step_idx] = loss_val
    return merged
| | 276 | + |
| | 277 | + |
| | 278 | +# --------------------------------------------------------------------------- |
| | 279 | +# Metric computation |
| | 280 | +# --------------------------------------------------------------------------- |
| | 281 | + |
| | 282 | + |
def _compute_metrics(
    losses: np.ndarray,
    *,
    rolling_window: int,
    spike_sigma: float,
) -> dict[str, float]:
    """Compute the headline metrics from the loss array.

    Non-finite step losses are forward-filled from the most recent
    finite value (leading non-finite entries borrow the first finite
    value) before any statistics are taken, so one bad batch cannot
    poison the curve — but every such position still counts as one
    instability event.
    """
    losses = losses.astype(np.float64, copy=True)
    finite = np.isfinite(losses)
    nonfinite_count = int(np.sum(~finite))

    if nonfinite_count:
        if not finite.any():
            # Pathological: every step recorded NaN/inf. No statistic
            # is meaningful; NaN losses plus zero smoothness let the
            # verdict path surface the failure.
            return {
                "initial_loss": float("nan"),
                "final_loss": float("nan"),
                "convergence_ratio": float("nan"),
                "smoothness": 0.0,
                "instability_events": float(len(losses)),
            }
        # Vectorized forward-fill: map each position to the index of
        # the latest finite entry at-or-before it; positions before
        # the first finite entry fall back to that first finite index.
        src = np.maximum.accumulate(np.where(finite, np.arange(losses.size), -1))
        src[src < 0] = int(np.flatnonzero(finite)[0])
        losses = losses[src]

    initial_loss = float(losses[0])
    final_loss = float(losses[-1])
    if initial_loss != 0.0:
        convergence_ratio = float(final_loss / initial_loss)
    else:
        convergence_ratio = float("inf")

    deltas = np.diff(losses)
    var_loss = float(losses.var())
    var_delta = float(deltas.var()) if deltas.size else 0.0
    if var_loss > 0.0:
        # var(Δloss) can exceed var(loss) when per-step jitter
        # dominates the overall sweep; clamp at 0 (fully spiky)
        # rather than emit a negative smoothness.
        smoothness = max(0.0, 1.0 - var_delta / var_loss)
    else:
        # Identical loss at every step: formally a perfectly flat
        # curve, but the run never progressed — report 0.0 so the
        # verdict path can flag it.
        smoothness = 0.0

    spikes = _count_spikes(deltas, window=rolling_window, sigma=spike_sigma)

    return {
        "initial_loss": initial_loss,
        "final_loss": final_loss,
        "convergence_ratio": convergence_ratio,
        "smoothness": smoothness,
        "instability_events": float(spikes + nonfinite_count),
    }


def _count_spikes(deltas: np.ndarray, *, window: int, sigma: float) -> int:
    """Count loss-*increase* events that exceed a robust local threshold.

    Plain ``|Δloss| > sigma · rolling_std`` misfires on exponential
    decay: deltas shrink over orders of magnitude, so a tiny
    within-window std makes nearly every step look like a spike. An
    instability event is semantically a loss *increase*, hence:

    1. Only positive deltas are candidates — a faster-than-usual
       decrease is the happy path, never an instability.
    2. Each candidate is compared against ``sigma`` times the median
       absolute delta of a centered window (a MAD-style scale, robust
       to outliers), with the candidate itself excluded from its own
       baseline.
    3. Curves shorter than ``window`` fall back to a single global
       median; all-flat curves report zero spikes.

    ``sigma`` keeps its meaning of "how many typical movements does
    this exceed", just measured against the median scale instead of
    std-of-delta.
    """
    n = int(deltas.size)
    if n == 0:
        return 0

    # Typical per-step movement scale that candidates are judged
    # against.
    magnitudes = np.abs(deltas)

    if n < window:
        scale = float(np.median(magnitudes))
        if scale == 0.0:
            return 0
        return int(np.sum((deltas > 0.0) & (deltas > sigma * scale)))

    half = window // 2
    count = 0
    for i, delta in enumerate(deltas):
        if delta <= 0.0:
            continue  # Loss went down — not an instability event.
        lo = max(0, i - half)
        hi = min(n, i + half + 1)
        neighborhood = magnitudes[lo:hi]
        if neighborhood.size > 1:
            # Drop the candidate itself so a genuine outlier cannot
            # inflate the baseline it is measured against.
            keep = np.ones(neighborhood.size, dtype=bool)
            keep[i - lo] = False
            neighborhood = neighborhood[keep]
        scale = float(np.median(neighborhood))
        # Zero local scale means every neighbor was flat: any positive
        # delta is a spike by construction.
        if scale == 0.0 or float(delta) > sigma * scale:
            count += 1
    return count
| | 416 | + |
| | 417 | + |
| | 418 | +# --------------------------------------------------------------------------- |
| | 419 | +# Verdict mapping |
| | 420 | +# --------------------------------------------------------------------------- |
| | 421 | + |
| | 422 | + |
def _verdict_from_metrics(
    metrics: dict[str, float], spec: TrainingDriftSpec
) -> tuple[Verdict, float, str]:
    """Map the computed metrics to ``(verdict, score, message)``.

    PASS requires all three thresholds to clear; any miss downgrades
    to WARN (this mapping never emits FAIL). The score is derived
    from smoothness alone and does not feed back into the verdict.
    """
    smoothness = metrics["smoothness"]
    ratio = metrics["convergence_ratio"]
    spikes = int(metrics["instability_events"])

    # ``not (x <= t)`` keeps the original NaN behavior: a NaN metric
    # compares False and is therefore reported as a failure.
    problems: list[str] = []
    if not (smoothness >= spec.assert_smoothness_gte):
        problems.append(f"smoothness={smoothness:.2f} < {spec.assert_smoothness_gte}")
    if not (ratio <= spec.assert_convergence_ratio_lte):
        problems.append(
            f"convergence_ratio={ratio:.2f} > {spec.assert_convergence_ratio_lte}"
        )
    if not (spikes <= spec.assert_instability_events_lte):
        problems.append(
            f"instability_events={spikes} > {spec.assert_instability_events_lte}"
        )

    headline = (
        f"smoothness={smoothness:.2f}, convergence_ratio={ratio:.2f}, "
        f"instability_events={spikes}, final_loss={metrics['final_loss']:.3f}"
    )

    if problems:
        # Halving smoothness keeps every WARN score below the PASS
        # score the same curve would have earned, so reports can rank
        # "borderline warn" against "actively bad."
        warn_score = float(min(1.0, max(0.0, smoothness * 0.5)))
        return Verdict.WARN, warn_score, f"{headline}; warnings: {'; '.join(problems)}"

    # All thresholds clear: score is the clamped smoothness itself.
    pass_score = float(min(1.0, max(0.0, smoothness)))
    return Verdict.PASS, pass_score, headline
| | 460 | + |
| | 461 | + |
| | 462 | +# --------------------------------------------------------------------------- |
| | 463 | +# Curve downsampling for evidence |
| | 464 | +# --------------------------------------------------------------------------- |
| | 465 | + |
| | 466 | + |
| | 467 | +def _downsampled_curve( |
| | 468 | + steps: np.ndarray, losses: np.ndarray, *, cap: int |
| | 469 | +) -> list[tuple[int, float]]: |
| | 470 | + """Uniform-stride downsample so a 10k-step run still fits in the JSON report. |
| | 471 | + |
| | 472 | + Always preserves the first and last point so initial/final loss |
| | 473 | + are visible regardless of cap. For curves shorter than the cap, |
| | 474 | + returns the full series unchanged. |
| | 475 | + """ |
| | 476 | + n = int(len(losses)) |
| | 477 | + if n <= cap: |
| | 478 | + return [(int(s), float(loss)) for s, loss in zip(steps, losses, strict=True)] |
| | 479 | + # Stride to keep at most ``cap`` points: stride = ceil(n / cap). |
| | 480 | + # The +1 accounts for always-appending the final point even when |
| | 481 | + # ``stride * (cap - 1) < n - 1``. |
| | 482 | + stride = max(1, (n + cap - 1) // cap) |
| | 483 | + idx = list(range(0, n, stride)) |
| | 484 | + if idx[-1] != n - 1: |
| | 485 | + idx.append(n - 1) |
| | 486 | + return [(int(steps[i]), float(losses[i])) for i in idx] |