@@ -0,0 +1,142 @@
| 1 | +"""Post-training weight-integrity gate. |
| 2 | + |
| 3 | +Invariant |
| 4 | +--------- |
| 5 | + |
| 6 | +Training must never persist an adapter whose weights contain NaN or |
| 7 | +infinite values. Downstream consumers (inference, export, evaluation) |
| 8 | +would then silently produce NaN logits — and the `dlm train` exit code |
| 9 | +would be 0, giving the user no signal that the run failed. |
| 10 | + |
| 11 | +This module walks the trainable parameters of a wrapped PEFT model and |
| 12 | +asserts every element is finite. The check runs inside the |
| 13 | +`_write_checkpoint` writer, between `save_model` and the caller's |
| 14 | +atomic pointer flip. If the check fails, the exception propagates up |
| 15 | +through `commit_version` — which by contract leaves the pending |
| 16 | +version directory on disk and does **not** flip `current.txt`. The
| 17 | +trainer then renames the pending directory with a `-rejected` suffix
| 18 | +so that the next `allocate_next_version` skips it (the suffix makes
| 19 | +the name unparseable as `vNNNN`) and the user can inspect the bad
| 20 | +weights for a postmortem.
| 21 | + |
| 22 | +Only trainable parameters are checked. Frozen base-model weights are |
| 23 | +both huge and (by construction) unchanged by training, so validating |
| 24 | +them is wasted compute without corresponding signal.
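|   | +
|   | +A minimal sketch of the caller-side flow this contract implies (the
|   | +names `repo_root`, `writer`, and `pending` are illustrative stand-ins,
|   | +not the trainer's literal code)::
|   | +
|   | +    try:
|   | +        commit_version(repo_root, writer=_write_checkpoint)
|   | +    except NaNWeightsError:
|   | +        # By contract the pending dir survived and current.txt did not
|   | +        # flip, so the dir can be renamed for postmortem inspection.
|   | +        pending.rename(pending.with_name(pending.name + "-rejected"))
|   | +        raise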
| 25 | +""" |
| 26 | + |
| 27 | +from __future__ import annotations |
| 28 | + |
|   | +import math
| 29 | +from dataclasses import dataclass
| 30 | +from typing import Any
| 31 | + |
| 32 | +from dlm.train.errors import TrainingError |
| 33 | + |
| 37 | + |
| 38 | +class NaNEvalError(TrainingError): |
| 39 | + """The trainer's final eval_loss was NaN or infinite — refusing to commit. |
| 40 | + |
| 41 | + Fires in the trainer orchestrator after `summarize_eval_state` and |
| 42 | + before the two-phase commit. Redundant with `NaNWeightsError`: when |
| 43 | + eval diverges, the weights are usually already corrupt, so both |
| 44 | + gates catch the same failure from different angles. The redundancy |
| 45 | + is intentional — catching it twice is cheap, catching it zero times |
| 46 | + (the original bug) silently poisons downstream consumers. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self, value: float) -> None: |
| 50 | + super().__init__( |
| 51 | + f"final eval_loss is non-finite ({value!r}) — refusing to commit " |
| 52 | + "adapter. Lower LR, add warmup, or check dataset for NaN-producing inputs." |
| 53 | + ) |
| 54 | + self.value = value |
| 55 | + |
| 56 | + |
| 57 | +def assert_eval_finite(log_history: list[Any]) -> None: |
| 58 | + """Raise `NaNEvalError` if the last `eval_loss` entry is non-finite. |
| 59 | + |
| 60 | + No-op when `log_history` is empty or contains no eval entries — |
| 61 | +    eval is optional, and runs without an `eval_dataset` legitimately
| 62 | + have no eval_loss to check. |
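|   | +
|   | +    Example, on a history shaped like HF `Trainer.state.log_history`
|   | +    (values are illustrative)::
|   | +
|   | +        assert_eval_finite([])                                   # no eval ran: no-op
|   | +        assert_eval_finite([{"loss": 1.9}, {"eval_loss": 0.7}])  # finite: passes
|   | +        assert_eval_finite([{"eval_loss": float("nan")}])        # raises NaNEvalError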
| 63 | + """ |
| 66 | + for entry in reversed(log_history): |
| 67 | + if not isinstance(entry, dict): |
| 68 | + continue |
| 69 | + value = entry.get("eval_loss") |
| 70 | + if isinstance(value, (int, float)): |
| 71 | + if not math.isfinite(float(value)): |
| 72 | + raise NaNEvalError(float(value)) |
| 73 | + return # last eval was finite — good |
| 74 | + # No eval entries — the caller's contract is "check iff eval ran". |
| 75 | + |
| 76 | + |
| 77 | +class NaNWeightsError(TrainingError): |
| 78 | + """A trained adapter contains non-finite weights — refusing to persist. |
| 79 | + |
| 80 | + `offending` is the list of parameter names that failed the check. |
| 81 | + Truncated to the first 20 names for readability; the full list is |
| 82 | + available via the `full_offending` attribute. |
| 83 | + """ |
| 84 | + |
| 85 | + def __init__(self, offending: list[str]) -> None: |
| 86 | + display = offending[:20] |
| 87 | + suffix = f" (and {len(offending) - 20} more)" if len(offending) > 20 else "" |
| 88 | + super().__init__( |
| 89 | + f"adapter weights contain NaN/inf in {len(offending)} tensor(s) — " |
| 90 | + f"refusing to persist. Offenders: {', '.join(display)}{suffix}" |
| 91 | + ) |
| 92 | + self.offending = display |
| 93 | + self.full_offending = offending |
| 94 | + |
| 95 | + |
| 96 | +@dataclass(frozen=True) |
| 97 | +class FiniteCheckResult: |
| 98 | + """Outcome of a finite-weight audit, used by unit tests + the gate.""" |
| 99 | + |
| 100 | + checked: int |
| 101 | + offending: tuple[str, ...] |
| 102 | + |
| 103 | + @property |
| 104 | + def ok(self) -> bool: |
| 105 | + return not self.offending |
| 106 | + |
| 107 | + |
| 108 | +def audit_trainable_finite(model: Any) -> FiniteCheckResult: |
| 109 | + """Walk `model.named_parameters()` and flag non-finite trainable tensors. |
| 110 | + |
| 111 | + Only `requires_grad=True` parameters are inspected — these are the |
| 112 | + LoRA `lora_A` / `lora_B` / `modules_to_save` tensors that training |
| 113 | + actually updates. Frozen base weights are skipped. |
| 114 | + |
| 115 | + Returns a `FiniteCheckResult` even when everything is finite, so |
| 116 | + callers can log the count of checked tensors. |
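|   | +
|   | +    Example on a toy module (anything exposing `named_parameters()`
|   | +    works; a real PEFT wrapper is not required)::
|   | +
|   | +        import torch
|   | +        layer = torch.nn.Linear(2, 2)
|   | +        with torch.no_grad():
|   | +            layer.weight[0, 0] = float("inf")
|   | +        result = audit_trainable_finite(layer)
|   | +        assert result.checked == 2
|   | +        assert result.offending == ("weight",)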
| 117 | + """ |
| 118 | + import torch |
| 119 | + |
| 120 | + offending: list[str] = [] |
| 121 | + checked = 0 |
| 122 | + for name, param in model.named_parameters(): |
| 123 | + if not getattr(param, "requires_grad", False): |
| 124 | + continue |
| 125 | + checked += 1 |
| 126 | + if not torch.isfinite(param).all(): |
| 127 | + offending.append(name) |
| 128 | + return FiniteCheckResult(checked=checked, offending=tuple(offending)) |
| 129 | + |
| 130 | + |
| 131 | +def assert_finite_adapter(model: Any) -> None: |
| 132 | + """Raise `NaNWeightsError` if any trainable parameter is non-finite. |
| 133 | + |
| 134 | + Called from `_write_checkpoint` right after `sft.save_model()`. If |
| 135 | + the weights are bad, the version directory on disk still contains |
| 136 | + the saved (bad) adapter_model.safetensors — the trainer relies on |
| 137 | + `commit_version` leaving the pending dir alone when the writer |
| 138 | + raises, and then renames it to `{pending}-rejected` for postmortem. |
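|   | +
|   | +    A sketch of that call site (the writer's exact signature is
|   | +    illustrative here)::
|   | +
|   | +        def _write_checkpoint(sft, out_dir):
|   | +            sft.save_model(str(out_dir))
|   | +            assert_finite_adapter(sft.model)  # raises before the pointer flip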
| 139 | + """ |
| 140 | + result = audit_trainable_finite(model) |
| 141 | + if not result.ok: |
| 142 | + raise NaNWeightsError(list(result.offending)) |