tenseleyflow/sway / 58b0322

Browse files

cli/check: gradient_ghost in quick battery + pre-flight banner on FAIL/WARN (S25 P6)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
58b032270a4e936517039fffb62b57367caf6830
Parents
cb71687
Tree
7eae4ef

1 changed file

StatusFile+-
M src/dlm_sway/cli/commands.py 55 0
src/dlm_sway/cli/commands.pymodified
@@ -347,6 +347,15 @@ def check_cmd(
347347
         models=SuiteModels(base=base_spec, ft=ft_spec),
348348
         defaults=SuiteDefaults(seed=0),
349349
         suite=[
350
+            # S25: pre-run training-health check first. SKIPs cleanly
351
+            # when the adapter wasn't produced by dlm (no
352
+            # training_state.pt); FAILs loudly on severely-undertrained
353
+            # adapters with a banner before the rest of the output.
354
+            {
355
+                "name": "quick_gradient_ghost",
356
+                "kind": "gradient_ghost",
357
+                "adapter_path": str(adapter),
358
+            },
350359
             # Calibrate first so delta_kl can publish a z-score the
351360
             # banner reads off.
352361
             {"name": "quick_null", "kind": "null_adapter", "runs": 3},
@@ -378,6 +387,13 @@ def check_cmd(
378387
     # D12: top-line banner before the full report so a user looking
379388
     # only at the first line still gets the verdict.
380389
     console = Console()
390
+
391
+    # S25 — pre-flight gradient_ghost banner. Fires BEFORE the verdict
392
+    # banner so the user sees "this adapter is undertrained" first;
393
+    # the rest of the check output stays for context (the user might
394
+    # still want to see how badly the other probes scored).
395
+    _emit_gradient_ghost_banner(result, console)
396
+
381397
     banner_text, banner_style = _check_banner(score_obj, result)
382398
     console.print()
383399
     console.print(banner_text, style=banner_style)
@@ -385,6 +401,45 @@ def check_cmd(
385401
     report.to_terminal(result, score_obj, console=console)
386402
 
387403
 
404
+def _emit_gradient_ghost_banner(result: object, console: Console) -> None:
405
+    """Print a yellow/red ⚠️ banner if gradient_ghost FAILed (S25 P6).
406
+
407
+    Reaches into ``result.probes`` for any probe with
408
+    ``kind=gradient_ghost`` and verdict FAIL. Informational — no
409
+    effect on exit code; the user might still want to inspect the
410
+    other probes' verdicts.
411
+    """
412
+    probes = getattr(result, "probes", ()) or ()
413
+    for p in probes:
414
+        if getattr(p, "kind", "") != "gradient_ghost":
415
+            continue
416
+        verdict_str = str(getattr(p, "verdict", "")).lower()
417
+        if verdict_str == "fail":
418
+            console.print()
419
+            console.print(
420
+                "⚠️  PRE-RUN ALERT — gradient_ghost flagged severe undertraining",
421
+                style="bold red",
422
+            )
423
+            msg = getattr(p, "message", "")
424
+            if msg:
425
+                console.print(f"   {msg}", style="red")
426
+            console.print(
427
+                "   The probe scores below may be unreliable. Consider retraining.",
428
+                style="dim red",
429
+            )
430
+            return
431
+        if verdict_str == "warn":
432
+            console.print()
433
+            console.print(
434
+                "⚠️  gradient_ghost: training may not have fully converged",
435
+                style="bold yellow",
436
+            )
437
+            msg = getattr(p, "message", "")
438
+            if msg:
439
+                console.print(f"   {msg}", style="yellow")
440
+            return
441
+
442
+
388443
 def diff_cmd(
389444
     spec: Annotated[Path, typer.Argument(help="Path to a sway.yaml spec.")],
390445
     adapter_a: Annotated[Path, typer.Option("--a", help="First adapter path.")],