"""Post-training val-loss split by row mode.

`TrainingSummary.val_loss_cpt` / `val_loss_sft` and
`split_loss_by_mode` capture the mixed-mode breakdown. This module
closes the loop: given a trained `SFTTrainer` and its `val_ds`, split
the dataset by row mode (CPT prose vs SFT instruction) and run
`trainer.evaluate()` on each non-empty subset to extract per-mode
`eval_loss`.

Kept small and pure-wrapper. Heavy eval lives in TRL; we just
group rows and read `eval_loss` out of the returned dict. Unit
tests drive the grouping logic with a mock trainer.
"""
|
| 15 |
from __future__ import annotations |
| 16 |
|
| 17 |
import logging |
| 18 |
from typing import Any |
| 19 |
|
| 20 |
_LOG = logging.getLogger(__name__) |
| 21 |
|
| 22 |
|
| 23 |
def compute_val_loss_by_mode(trainer: Any, val_ds: Any) -> tuple[float | None, float | None]:
    """Return `(val_loss_cpt, val_loss_sft)` from a post-train eval pass.

    Partitions `val_ds` with the `dlm.train.cpt.runtime.row_mode`
    classifier into CPT-only and SFT-only subsets, runs
    `trainer.evaluate()` on each non-empty subset, and reads back the
    resulting `eval_loss` values. A mode with no rows in the val set
    yields `None`.

    Non-fatal by design: if an eval call raises (stack version drift,
    NaN logits from an undertrained tiny model, etc.) the affected
    mode's loss simply stays `None`. The summary gets whatever signal
    is reliably extractable without killing the training run.
    """
    from dlm.train.cpt.runtime import row_mode

    if val_ds is None:
        return (None, None)
    try:
        is_empty = len(val_ds) == 0
    except TypeError:
        # Unsized dataset: bail gracefully rather than crashing.
        return (None, None)
    if is_empty:
        return (None, None)

    # Bucket row *indices* by mode — subsets are later materialized via
    # HF Dataset.select() so rows are never duplicated into memory.
    buckets: dict[str, list[int]] = {"cpt": [], "sft": []}
    for idx, row in enumerate(val_ds):
        bucket = buckets.get(row_mode(row))
        if bucket is not None:
            bucket.append(idx)

    return (
        _safe_eval_loss(trainer, val_ds, buckets["cpt"], mode="cpt"),
        _safe_eval_loss(trainer, val_ds, buckets["sft"], mode="sft"),
    )
| 62 |
|
| 63 |
|
| 64 |
def _safe_eval_loss(trainer: Any, val_ds: Any, indices: list[int], *, mode: str) -> float | None: |
| 65 |
"""Run `trainer.evaluate(eval_dataset=subset)`; return eval_loss or None.""" |
| 66 |
if not indices: |
| 67 |
return None |
| 68 |
try: |
| 69 |
subset = val_ds.select(indices) |
| 70 |
except (AttributeError, IndexError, TypeError, ValueError) as exc: |
| 71 |
_LOG.warning( |
| 72 |
"val-loss split skipped %s subset selection (%d rows): %s", |
| 73 |
mode, |
| 74 |
len(indices), |
| 75 |
exc, |
| 76 |
) |
| 77 |
return None |
| 78 |
try: |
| 79 |
metrics = trainer.evaluate(eval_dataset=subset) |
| 80 |
except (RuntimeError, TypeError, ValueError) as exc: |
| 81 |
_LOG.warning( |
| 82 |
"val-loss split skipped %s evaluation (%d rows): %s", |
| 83 |
mode, |
| 84 |
len(indices), |
| 85 |
exc, |
| 86 |
) |
| 87 |
return None |
| 88 |
loss = metrics.get("eval_loss") if isinstance(metrics, dict) else None |
| 89 |
if loss is None: |
| 90 |
return None |
| 91 |
try: |
| 92 |
return float(loss) |
| 93 |
except (TypeError, ValueError): |
| 94 |
return None |