1 """Post-training val-loss split by row mode.
2
3 `TrainingSummary.val_loss_cpt` / `val_loss_sft` and
4 `split_loss_by_mode` capture the mixed-mode breakdown. This module
5 closes the loop: given a trained `SFTTrainer` and its `val_ds`, split
6 the dataset by row mode (CPT prose vs SFT instruction) and run
7 `trainer.evaluate()` on each non-empty subset to extract per-mode
8 `eval_loss`.
9
10 Kept small and pure-wrapper. Heavy eval lives in TRL; we just
11 group rows and read `eval_loss` out of the returned dict. Unit
12 tests drive the grouping logic with a mock trainer.
13 """
14
15 from __future__ import annotations
16
17 import logging
18 from typing import Any
19
20 _LOG = logging.getLogger(__name__)
21
22
def compute_val_loss_by_mode(trainer: Any, val_ds: Any) -> tuple[float | None, float | None]:
    """Return `(val_loss_cpt, val_loss_sft)` from a post-train eval pass.

    Partitions `val_ds` into CPT-only and SFT-only subsets via the
    `dlm.train.cpt.runtime.row_mode` classifier, then runs
    `trainer.evaluate()` on each non-empty subset and reports the
    `eval_loss` it produces. A mode with no rows in the val set
    yields `None`.

    Non-fatal by design: if an eval pass raises (stack version
    drift, NaN logits from an undertrained tiny model, ...) that
    mode's loss simply stays `None` — we surface whatever signal is
    reliably extractable without killing the training run.
    """
    from dlm.train.cpt.runtime import row_mode

    if val_ds is None:
        return (None, None)
    try:
        is_empty = len(val_ds) == 0
    except TypeError:
        # Unsized dataset: nothing sensible to split, bail gracefully.
        return (None, None)
    if is_empty:
        return (None, None)

    # Bucket row indices by mode; filtering through HF Dataset.select()
    # on indices means we never duplicate rows into memory.
    buckets: dict[str, list[int]] = {"cpt": [], "sft": []}
    for idx, row in enumerate(val_ds):
        bucket = buckets.get(row_mode(row))
        if bucket is not None:
            bucket.append(idx)

    return (
        _safe_eval_loss(trainer, val_ds, buckets["cpt"], mode="cpt"),
        _safe_eval_loss(trainer, val_ds, buckets["sft"], mode="sft"),
    )
62
63
64 def _safe_eval_loss(trainer: Any, val_ds: Any, indices: list[int], *, mode: str) -> float | None:
65 """Run `trainer.evaluate(eval_dataset=subset)`; return eval_loss or None."""
66 if not indices:
67 return None
68 try:
69 subset = val_ds.select(indices)
70 except (AttributeError, IndexError, TypeError, ValueError) as exc:
71 _LOG.warning(
72 "val-loss split skipped %s subset selection (%d rows): %s",
73 mode,
74 len(indices),
75 exc,
76 )
77 return None
78 try:
79 metrics = trainer.evaluate(eval_dataset=subset)
80 except (RuntimeError, TypeError, ValueError) as exc:
81 _LOG.warning(
82 "val-loss split skipped %s evaluation (%d rows): %s",
83 mode,
84 len(indices),
85 exc,
86 )
87 return None
88 loss = metrics.get("eval_loss") if isinstance(metrics, dict) else None
89 if loss is None:
90 return None
91 try:
92 return float(loss)
93 except (TypeError, ValueError):
94 return None