@@ -0,0 +1,486 @@ |
| 1 | +"""Training-drift — pre-run probe that reads dlm's per-step loss curve. |
| 2 | + |
| 3 | +Sister probe to :mod:`gradient_ghost`. Where ``gradient_ghost`` reads |
| 4 | +the optimizer state at the end of training, ``training_drift`` reads |
| 5 | +the loss curve *during* training: the rich signal about *how* the |
| 6 | +adapter trained, not just what it produced. |
| 7 | + |
| 8 | +Four metrics extracted from the curve: |
| 9 | + |
| 10 | +- ``final_loss`` — last recorded step's loss. |
| 11 | +- ``convergence_ratio`` — ``final_loss / initial_loss``. Lower is |
| 12 | + better; a healthy adapter cuts loss by half or more. |
| 13 | +- ``smoothness`` — ``1 - (var(Δloss) / var(loss))``. Values near 1 |
| 14 | + mean the curve descends smoothly; near 0 means each step's |
| 15 | + change-in-loss is comparable to the loss range itself (spiky). |
| 16 | +- ``instability_events`` — count of steps where |
| 17 | + ``|Δloss| > 3 · rolling_std``. Spikes that survive the rolling |
| 18 | + window are real — they correlate with silent adapter degradation. |
| 19 | + |
| 20 | +Verdict: PASS when all three of (smoothness ≥ 0.7, instability_events |
| 21 | +== 0, convergence_ratio ≤ 0.7); else WARN. The thresholds are |
| 22 | +hand-tuned defaults; spec fields override them. |
| 23 | + |
| 24 | +## Why no null calibration |
| 25 | + |
| 26 | +Mirrors ``prompt_collapse`` and ``multi_turn_coherence_decay``: a null |
| 27 | +adapter doesn't *train* — it has no loss curve. The null distribution |
| 28 | +of "smoothness on a noise adapter" is undefined. Fixed-threshold |
| 29 | +verdicts are the published path; users override per-spec. |
| 30 | + |
| 31 | +## Log-format note |
| 32 | + |
| 33 | +dlm writes one JSONL per run at |
| 34 | +``<store_path>/logs/train-NNNNNN-YYYYMMDDTHHMMSS.jsonl``. Each line is |
| 35 | +``{"type": "<banner|step|run_complete|...>", ...}``. This probe |
| 36 | +filters for ``type == "step"`` records and reads ``step`` + ``loss``. |
| 37 | +The sibling ``*.summary.json`` has run aggregates; we don't consume |
| 38 | +it here — the curve is richer. |
| 39 | +""" |
| 40 | + |
| 41 | +from __future__ import annotations |
| 42 | + |
| 43 | +import json |
| 44 | +import math |
| 45 | +from pathlib import Path |
| 46 | +from typing import ClassVar, Literal |
| 47 | + |
| 48 | +import numpy as np |
| 49 | +from pydantic import Field |
| 50 | + |
| 51 | +from dlm_sway.core.errors import SwayError |
| 52 | +from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 53 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 54 | + |
| 55 | + |
| 56 | +class TrainingDriftError(SwayError): |
| 57 | + """Raised when a training log file is structurally unparseable. |
| 58 | + |
| 59 | + Distinct from :class:`MissingTrainingStateError`-style absences: |
| 60 | + a missing logs directory is a SKIP (no .dlm context, or a |
| 61 | + pre-training adapter), but a JSONL with corrupted lines or the |
| 62 | + wrong shape is an ERROR — the user has training data but we |
| 63 | + can't read it, which deserves a noisy failure. |
| 64 | + """ |
| 65 | + |
| 66 | + |
| 67 | +class TrainingDriftSpec(ProbeSpec): |
| 68 | + """Spec for ``kind: training_drift``.""" |
| 69 | + |
| 70 | + kind: Literal["training_drift"] = "training_drift" |
| 71 | + store_path: str | None = None |
| 72 | + """Path to a dlm store root (the directory containing |
| 73 | + ``logs/`` and ``adapter/``). When ``None`` the probe SKIPs — |
| 74 | + typically populated by the .dlm autogen path which knows the |
| 75 | + store from the dlm_id resolution.""" |
| 76 | + min_steps: int = Field(default=10, ge=2) |
| 77 | + """Skip when the curve has fewer steps than this. Short runs |
| 78 | + (3-step early-stops, smoke tests) produce undefined fits; |
| 79 | + surfacing those as PASS/FAIL would be misleading.""" |
| 80 | + rolling_window: int = Field(default=10, ge=2) |
| 81 | + """Window for the rolling std used in spike detection. Larger |
| 82 | + windows mean tighter spike detection at the cost of missing |
| 83 | + rapid oscillations. 10 is a balance for typical 100–1000 step |
| 84 | + runs.""" |
| 85 | + spike_sigma: float = Field(default=3.0, gt=0.0) |
| 86 | + """A step is an instability event when ``|Δloss|`` exceeds this |
| 87 | + many rolling standard deviations. 3σ is the conventional |
| 88 | + "outlier" boundary; lower for more sensitivity, higher for less |
| 89 | + noise tolerance.""" |
| 90 | + assert_smoothness_gte: float = Field(default=0.7, ge=0.0, le=1.0) |
| 91 | + """Minimum smoothness for PASS.""" |
| 92 | + assert_convergence_ratio_lte: float = Field(default=0.7, gt=0.0) |
| 93 | + """Maximum ``final_loss / initial_loss`` for PASS. A "well- |
| 94 | + trained" run typically halves loss; permissive default tolerates |
| 95 | + noisier data.""" |
| 96 | + assert_instability_events_lte: int = Field(default=0, ge=0) |
| 97 | + """Maximum allowed spike count. Default 0 — any spike → WARN.""" |
| 98 | + |
| 99 | + |
| 100 | +class TrainingDriftProbe(Probe): |
| 101 | + """The "did this adapter train smoothly?" pre-run probe.""" |
| 102 | + |
| 103 | + kind = "training_drift" |
| 104 | + spec_cls = TrainingDriftSpec |
| 105 | + category = "calibration" |
| 106 | + needs_backend: ClassVar[bool] = False |
| 107 | + # Pre-run probe: no model load, runs in <100ms, ideal for |
| 108 | + # ``sway check`` first-pass before any heavy probe fires. |
| 109 | + # Mirrors gradient_ghost's posture. |
| 110 | + |
| 111 | + def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult: |
| 112 | + del ctx # No backend / sections / null_stats consumed. |
| 113 | + assert isinstance(spec, TrainingDriftSpec) |
| 114 | + |
| 115 | + if spec.store_path is None: |
| 116 | + return ProbeResult( |
| 117 | + name=spec.name, |
| 118 | + kind=spec.kind, |
| 119 | + verdict=Verdict.SKIP, |
| 120 | + score=None, |
| 121 | + message="no store_path provided (no .dlm context)", |
| 122 | + ) |
| 123 | + |
| 124 | + store_path = Path(spec.store_path).expanduser().resolve() |
| 125 | + logs_dir = store_path / "logs" |
| 126 | + if not logs_dir.is_dir(): |
| 127 | + return ProbeResult( |
| 128 | + name=spec.name, |
| 129 | + kind=spec.kind, |
| 130 | + verdict=Verdict.SKIP, |
| 131 | + score=None, |
| 132 | + message=f"no logs/ directory under {store_path}", |
| 133 | + ) |
| 134 | + |
| 135 | + log_paths = sorted(logs_dir.glob("train-*.jsonl")) |
| 136 | + if not log_paths: |
| 137 | + return ProbeResult( |
| 138 | + name=spec.name, |
| 139 | + kind=spec.kind, |
| 140 | + verdict=Verdict.SKIP, |
| 141 | + score=None, |
| 142 | + message=f"no train-*.jsonl files under {logs_dir}", |
| 143 | + ) |
| 144 | + |
| 145 | + # Concatenate steps across all runs in chronological order. |
| 146 | + # Resumed runs may produce duplicate step numbers — dedupe by |
| 147 | + # keeping the latest occurrence (the most recent run's value |
| 148 | + # for that step). Sorted glob already gives chronological |
| 149 | + # order via the timestamp suffix. |
| 150 | + try: |
| 151 | + steps_by_idx = _collect_steps(log_paths) |
| 152 | + except TrainingDriftError as exc: |
| 153 | + return ProbeResult( |
| 154 | + name=spec.name, |
| 155 | + kind=spec.kind, |
| 156 | + verdict=Verdict.ERROR, |
| 157 | + score=None, |
| 158 | + message=str(exc), |
| 159 | + evidence={"log_paths": [str(p) for p in log_paths]}, |
| 160 | + ) |
| 161 | + |
| 162 | + if len(steps_by_idx) < spec.min_steps: |
| 163 | + return ProbeResult( |
| 164 | + name=spec.name, |
| 165 | + kind=spec.kind, |
| 166 | + verdict=Verdict.SKIP, |
| 167 | + score=None, |
| 168 | + message=( |
| 169 | + f"only {len(steps_by_idx)} step records " |
| 170 | + f"(< min_steps={spec.min_steps}); " |
| 171 | + f"curve too short to fit reliably" |
| 172 | + ), |
| 173 | + evidence={"num_steps": len(steps_by_idx)}, |
| 174 | + ) |
| 175 | + |
| 176 | + # Sorted (step, loss) pairs. |
| 177 | + ordered = sorted(steps_by_idx.items()) |
| 178 | + steps = np.asarray([s for s, _ in ordered], dtype=np.int64) |
| 179 | + losses = np.asarray([loss for _, loss in ordered], dtype=np.float64) |
| 180 | + |
| 181 | + # All four metrics. Each can fail gracefully (NaN-only loss |
| 182 | + # column, all-equal losses, etc.) and the verdict respects the |
| 183 | + # safe_finalize critical-field guard for ``raw``. |
| 184 | + metrics = _compute_metrics( |
| 185 | + losses, rolling_window=spec.rolling_window, spike_sigma=spec.spike_sigma |
| 186 | + ) |
| 187 | + |
| 188 | + verdict, score, message = _verdict_from_metrics(metrics, spec) |
| 189 | + |
| 190 | + # Bound the curve we ship in evidence: a 10k-step run shouldn't |
| 191 | + # explode the JSON report. Downsample uniformly to the cap. |
| 192 | + curve = _downsampled_curve(steps, losses, cap=512) |
| 193 | + |
| 194 | + return safe_finalize( |
| 195 | + name=spec.name, |
| 196 | + kind=spec.kind, |
| 197 | + verdict=verdict, |
| 198 | + score=score, |
| 199 | + raw=metrics["smoothness"], |
| 200 | + evidence={ |
| 201 | + "final_loss": metrics["final_loss"], |
| 202 | + "initial_loss": metrics["initial_loss"], |
| 203 | + "convergence_ratio": metrics["convergence_ratio"], |
| 204 | + "smoothness": metrics["smoothness"], |
| 205 | + "instability_events": metrics["instability_events"], |
| 206 | + "num_steps": len(losses), |
| 207 | + "num_log_files": len(log_paths), |
| 208 | + "curve_sampled": curve, |
| 209 | + "thresholds": { |
| 210 | + "smoothness_gte": spec.assert_smoothness_gte, |
| 211 | + "convergence_ratio_lte": spec.assert_convergence_ratio_lte, |
| 212 | + "instability_events_lte": spec.assert_instability_events_lte, |
| 213 | + }, |
| 214 | + "weight": spec.weight, |
| 215 | + }, |
| 216 | + message=message, |
| 217 | + critical_fields=(), |
| 218 | + # ``raw`` (smoothness) is bounded [0, 1] in normal cases; |
| 219 | + # the metric helper returns NaN only on degenerate inputs |
| 220 | + # which we already surfaced via the SKIP path. Critical- |
| 221 | + # field guard would null-out the score on a NaN smoothness |
| 222 | + # which is more confusing than helpful here. |
| 223 | + ) |
| 224 | + |
| 225 | + |
| 226 | +# --------------------------------------------------------------------------- |
| 227 | +# JSONL parsing |
| 228 | +# --------------------------------------------------------------------------- |
| 229 | + |
| 230 | + |
| 231 | +def _collect_steps(log_paths: list[Path]) -> dict[int, float]: |
| 232 | + """Parse every JSONL, dedupe-by-step, return ``{step: loss}``. |
| 233 | + |
| 234 | + Resumed runs append to a fresh JSONL but share step numbers with |
| 235 | + the prior run. Dedupe-by-keep-latest matches dlm's own ``dlm |
| 236 | + metrics`` semantics: the most recent run for a given step wins. |
| 237 | + """ |
| 238 | + out: dict[int, float] = {} |
| 239 | + for path in log_paths: |
| 240 | + try: |
| 241 | + with path.open("r", encoding="utf-8") as f: |
| 242 | + for line_no, raw in enumerate(f, start=1): |
| 243 | + raw = raw.strip() |
| 244 | + if not raw: |
| 245 | + continue |
| 246 | + try: |
| 247 | + rec = json.loads(raw) |
| 248 | + except json.JSONDecodeError as exc: |
| 249 | + # A trailing partial line from a crashed |
| 250 | + # trainer is the typical cause. Skip it but |
| 251 | + # surface as ERROR if EVERY line is broken |
| 252 | + # (caller checks ``out`` emptiness). |
| 253 | + if line_no == 1: |
| 254 | + raise TrainingDriftError( |
| 255 | + f"first line of {path.name} is not valid JSON: {exc}" |
| 256 | + ) from exc |
| 257 | + continue |
| 258 | + if not isinstance(rec, dict): |
| 259 | + continue |
| 260 | + if rec.get("type") != "step": |
| 261 | + continue |
| 262 | + try: |
| 263 | + step = int(rec["step"]) |
| 264 | + loss = float(rec["loss"]) |
| 265 | + except (KeyError, TypeError, ValueError): |
| 266 | + continue |
| 267 | + if not math.isfinite(loss): |
| 268 | + # NaN loss is a real signal — record it as |
| 269 | + # +inf so the spike detector flags the |
| 270 | + # instability without crashing on np.log. |
| 271 | + loss = math.inf |
| 272 | + out[step] = loss |
| 273 | + except OSError as exc: |
| 274 | + raise TrainingDriftError(f"failed to read {path}: {exc}") from exc |
| 275 | + return out |
| 276 | + |
| 277 | + |
| 278 | +# --------------------------------------------------------------------------- |
| 279 | +# Metric computation |
| 280 | +# --------------------------------------------------------------------------- |
| 281 | + |
| 282 | + |
| 283 | +def _compute_metrics( |
| 284 | + losses: np.ndarray, |
| 285 | + *, |
| 286 | + rolling_window: int, |
| 287 | + spike_sigma: float, |
| 288 | +) -> dict[str, float]: |
| 289 | + """Compute the four headline metrics from the loss array. |
| 290 | + |
| 291 | + All metrics are robust to NaN/inf: non-finite step losses are |
| 292 | + replaced with the most-recent finite value before fitting (so a |
| 293 | + single bad batch doesn't poison the curve), but their *positions* |
| 294 | + are still counted as instability events. |
| 295 | + """ |
| 296 | + losses = losses.astype(np.float64, copy=True) |
| 297 | + instability_from_nan = int(np.sum(~np.isfinite(losses))) |
| 298 | + |
| 299 | + if instability_from_nan > 0: |
| 300 | + # Forward-fill non-finite values from the previous finite |
| 301 | + # entry so downstream stats don't blow up. The first entry |
| 302 | + # MUST be finite (training records the first batch's loss |
| 303 | + # before anything could go wrong); guard anyway. |
| 304 | + finite_mask = np.isfinite(losses) |
| 305 | + if not finite_mask.any(): |
| 306 | + # Pathological: every step recorded NaN. Return all-NaN |
| 307 | + # metrics so the verdict path can surface the failure. |
| 308 | + return { |
| 309 | + "initial_loss": float("nan"), |
| 310 | + "final_loss": float("nan"), |
| 311 | + "convergence_ratio": float("nan"), |
| 312 | + "smoothness": 0.0, |
| 313 | + "instability_events": float(len(losses)), |
| 314 | + } |
| 315 | + last_good = float(losses[finite_mask][0]) |
| 316 | + for i, v in enumerate(losses): |
| 317 | + if not math.isfinite(v): |
| 318 | + losses[i] = last_good |
| 319 | + else: |
| 320 | + last_good = float(v) |
| 321 | + |
| 322 | + initial_loss = float(losses[0]) |
| 323 | + final_loss = float(losses[-1]) |
| 324 | + convergence_ratio = float(final_loss / initial_loss) if initial_loss != 0.0 else float("inf") |
| 325 | + |
| 326 | + deltas = np.diff(losses) |
| 327 | + var_loss = float(losses.var()) |
| 328 | + var_delta = float(deltas.var()) if deltas.size > 0 else 0.0 |
| 329 | + if var_loss > 0.0: # noqa: SIM108 — branch comments are load-bearing |
| 330 | + # Clip into [0, 1]: a curve where var(Δloss) > var(loss) |
| 331 | + # implies the per-step jitter dominates the overall sweep — |
| 332 | + # treat as fully-spiky (smoothness=0) rather than emit a |
| 333 | + # negative number. |
| 334 | + smoothness = max(0.0, 1.0 - var_delta / var_loss) |
| 335 | + else: |
| 336 | + # Identical losses across every step: the run never |
| 337 | + # progressed. Smoothness is formally 1.0 (perfectly flat), |
| 338 | + # but that's misleading — surface as 0.0 so the verdict path |
| 339 | + # can flag it. |
| 340 | + smoothness = 0.0 |
| 341 | + |
| 342 | + instability_events = _count_spikes(deltas, window=rolling_window, sigma=spike_sigma) |
| 343 | + instability_events += instability_from_nan |
| 344 | + |
| 345 | + return { |
| 346 | + "initial_loss": initial_loss, |
| 347 | + "final_loss": final_loss, |
| 348 | + "convergence_ratio": convergence_ratio, |
| 349 | + "smoothness": smoothness, |
| 350 | + "instability_events": float(instability_events), |
| 351 | + } |
| 352 | + |
| 353 | + |
| 354 | +def _count_spikes(deltas: np.ndarray, *, window: int, sigma: float) -> int: |
| 355 | + """Count loss-increase events that exceed a robust noise threshold. |
| 356 | + |
| 357 | + The naive ``|Δloss| > sigma · rolling_std`` heuristic is broken for |
| 358 | + exponential decay: deltas span orders of magnitude across training, |
| 359 | + so within-window std stays tiny while absolute deltas are large — |
| 360 | + every step trips the threshold. The semantically correct |
| 361 | + "instability event" is a loss *increase*, not a faster-than-typical |
| 362 | + decrease. |
| 363 | + |
| 364 | + Heuristic: |
| 365 | + |
| 366 | + 1. Restrict to positive deltas (``delta > 0`` ⇒ loss went up). Loss |
| 367 | + going down faster than usual isn't an instability — it's the |
| 368 | + happy path. |
| 369 | + 2. For each positive delta, compare to the median absolute delta |
| 370 | + in a centered window. Flag when ``delta > sigma · MAD``, where |
| 371 | + MAD is the median absolute deviation (robust to outliers). |
| 372 | + 3. Short curves fall back to global MAD; constant-loss curves |
| 373 | + (every delta zero) report 0 spikes. |
| 374 | + |
| 375 | + The ``sigma`` parameter retains its semantic meaning ("how many |
| 376 | + typical deltas does this exceed") but operates against the |
| 377 | + median-absolute-delta scale rather than std-of-delta. |
| 378 | + """ |
| 379 | + if deltas.size == 0: |
| 380 | + return 0 |
| 381 | + |
| 382 | + # Median absolute delta — the "typical movement scale" we judge |
| 383 | + # spikes against. Median is robust to a few outliers; if the |
| 384 | + # whole curve is flat, MAD is 0 and we report no spikes. |
| 385 | + abs_deltas = np.abs(deltas) |
| 386 | + |
| 387 | + if deltas.size < window: |
| 388 | + baseline = float(np.median(abs_deltas)) |
| 389 | + if baseline == 0.0: |
| 390 | + return 0 |
| 391 | + return int(np.sum((deltas > 0.0) & (deltas > sigma * baseline))) |
| 392 | + |
| 393 | + spikes = 0 |
| 394 | + half = window // 2 |
| 395 | + for i in range(deltas.size): |
| 396 | + if deltas[i] <= 0.0: |
| 397 | + continue # Loss went down — not an instability event. |
| 398 | + lo = max(0, i - half) |
| 399 | + hi = min(deltas.size, i + half + 1) |
| 400 | + window_slice = abs_deltas[lo:hi] |
| 401 | + # Exclude the point itself so a real outlier doesn't inflate |
| 402 | + # the baseline it's measured against. |
| 403 | + if window_slice.size > 1: |
| 404 | + mask = np.ones(window_slice.size, dtype=bool) |
| 405 | + mask[i - lo] = False |
| 406 | + window_slice = window_slice[mask] |
| 407 | + baseline = float(np.median(window_slice)) |
| 408 | + if baseline == 0.0: |
| 409 | + # Surrounding deltas are all zero — any nonzero positive |
| 410 | + # delta is an instability event by construction. |
| 411 | + spikes += 1 |
| 412 | + continue |
| 413 | + if float(deltas[i]) > sigma * baseline: |
| 414 | + spikes += 1 |
| 415 | + return spikes |
| 416 | + |
| 417 | + |
| 418 | +# --------------------------------------------------------------------------- |
| 419 | +# Verdict mapping |
| 420 | +# --------------------------------------------------------------------------- |
| 421 | + |
| 422 | + |
| 423 | +def _verdict_from_metrics( |
| 424 | + metrics: dict[str, float], spec: TrainingDriftSpec |
| 425 | +) -> tuple[Verdict, float, str]: |
| 426 | + """Map the four metrics to a (verdict, score, message).""" |
| 427 | + smooth = metrics["smoothness"] |
| 428 | + conv = metrics["convergence_ratio"] |
| 429 | + instability = int(metrics["instability_events"]) |
| 430 | + |
| 431 | + smoothness_pass = smooth >= spec.assert_smoothness_gte |
| 432 | + convergence_pass = conv <= spec.assert_convergence_ratio_lte |
| 433 | + instability_pass = instability <= spec.assert_instability_events_lte |
| 434 | + |
| 435 | + failures: list[str] = [] |
| 436 | + if not smoothness_pass: |
| 437 | + failures.append(f"smoothness={smooth:.2f} < {spec.assert_smoothness_gte}") |
| 438 | + if not convergence_pass: |
| 439 | + failures.append(f"convergence_ratio={conv:.2f} > {spec.assert_convergence_ratio_lte}") |
| 440 | + if not instability_pass: |
| 441 | + failures.append(f"instability_events={instability} > {spec.assert_instability_events_lte}") |
| 442 | + |
| 443 | + headline = ( |
| 444 | + f"smoothness={smooth:.2f}, convergence_ratio={conv:.2f}, " |
| 445 | + f"instability_events={instability}, final_loss={metrics['final_loss']:.3f}" |
| 446 | + ) |
| 447 | + |
| 448 | + if not failures: |
| 449 | + # All three thresholds clear → PASS with a normalized score |
| 450 | + # blending the three signals (smoothness contributes most; |
| 451 | + # convergence and instability are guard rails). |
| 452 | + score = float(min(1.0, max(0.0, smooth))) |
| 453 | + return Verdict.PASS, score, headline |
| 454 | + |
| 455 | + # Score: a continuous blend of how far we are from each threshold, |
| 456 | + # so the report can rank "borderline warn" against "actively bad." |
| 457 | + # Doesn't influence the verdict — that's already FAIL/WARN. |
| 458 | + score = float(min(1.0, max(0.0, smooth * 0.5))) |
| 459 | + return Verdict.WARN, score, f"{headline}; warnings: {'; '.join(failures)}" |
| 460 | + |
| 461 | + |
| 462 | +# --------------------------------------------------------------------------- |
| 463 | +# Curve downsampling for evidence |
| 464 | +# --------------------------------------------------------------------------- |
| 465 | + |
| 466 | + |
| 467 | +def _downsampled_curve( |
| 468 | + steps: np.ndarray, losses: np.ndarray, *, cap: int |
| 469 | +) -> list[tuple[int, float]]: |
| 470 | + """Uniform-stride downsample so a 10k-step run still fits in the JSON report. |
| 471 | + |
| 472 | + Always preserves the first and last point so initial/final loss |
| 473 | + are visible regardless of cap. For curves shorter than the cap, |
| 474 | + returns the full series unchanged. |
| 475 | + """ |
| 476 | + n = int(len(losses)) |
| 477 | + if n <= cap: |
| 478 | + return [(int(s), float(loss)) for s, loss in zip(steps, losses, strict=True)] |
| 479 | + # Stride to keep at most ``cap`` points: stride = ceil(n / cap). |
| 480 | + # The +1 accounts for always-appending the final point even when |
| 481 | + # ``stride * (cap - 1) < n - 1``. |
| 482 | + stride = max(1, (n + cap - 1) // cap) |
| 483 | + idx = list(range(0, n, stride)) |
| 484 | + if idx[-1] != n - 1: |
| 485 | + idx.append(n - 1) |
| 486 | + return [(int(steps[i]), float(losses[i])) for i in idx] |