1 """Val loss → perplexity adapter for HF `compute_metrics`.
2
3 SFTTrainer calls `compute_metrics(eval_pred)` at every eval step when
4 we pass a callable as `SFTConfig.compute_metrics_for_all_tokens` /
5 `Trainer.compute_metrics`. The callable receives an `EvalPrediction`
6 namespace whose `.predictions` and `.label_ids` are post-batched tensors.
7
8 For language modeling we don't actually need the predictions — HF has
9 already computed the eval loss by the time `compute_metrics` fires, and
10 exposes it as `trainer.state.log_history[-1]["eval_loss"]`. The
11 `compute_metrics` hook exists so we can add derived metrics (perplexity)
12 that HF then logs alongside.
13
14 This module exports a single callable `eval_metrics_for_state(state)`
15 that pulls `eval_loss` out of the trainer state's log history and
16 returns the PPL dict; the trainer wires it directly into `SFTConfig`
17 as a closure.
18 """

from __future__ import annotations

import math
from typing import Any

from dlm.eval.perplexity import perplexity

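# How the hook is wired (a sketch only; the SFTTrainer/SFTConfig construction
# below is illustrative, not this repo's actual trainer setup):
#
#     from trl import SFTConfig, SFTTrainer
#
#     trainer = SFTTrainer(
#         model=model,
#         args=SFTConfig(...),
#         train_dataset=train_ds,
#         eval_dataset=val_ds,
#         compute_metrics=eval_metrics_from_eval_pred,
#     )
#     trainer.train()
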
def eval_metrics_from_eval_pred(eval_pred: Any) -> dict[str, float]:
    """Compute-metrics hook compatible with `Trainer.compute_metrics`.

    `eval_pred` is expected to be an `EvalPrediction`-like object; we
    only inspect a `metrics` attribute if the caller attached one (a
    plain `EvalPrediction` does not expose the eval loss this way). If
    `metrics` isn't present we return an empty dict; the HF side will
    still log `eval_loss` itself.
    """
    metrics = getattr(eval_pred, "metrics", None) or {}
    loss = metrics.get("eval_loss")
    if not isinstance(loss, (int, float)):
        return {}
    return {"perplexity": perplexity(float(loss))}

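# Illustrative call (assumes `dlm.eval.perplexity.perplexity` is exp(loss)):
#
#     from types import SimpleNamespace
#
#     eval_metrics_from_eval_pred(SimpleNamespace(metrics={"eval_loss": 0.0}))
#     # -> {"perplexity": 1.0}
#     eval_metrics_from_eval_pred(SimpleNamespace(metrics={}))
#     # -> {}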


def summarize_eval_state(log_history: list[dict[str, Any]]) -> dict[str, float | None]:
    """Extract `final_val_loss` + `final_val_perplexity` from trainer history.

    `log_history` is `trainer.state.log_history`: a list of dicts, one
    per logged metric snapshot. The last entry containing a finite
    `eval_loss` is the authoritative final eval result. Non-finite
    (NaN/inf) values are dropped silently;
    `dlm.train.integrity.assert_eval_finite` raises separately if the
    trainer produced a NaN eval, so the run is marked FAILED at the
    orchestration layer.
    """
    final_loss: float | None = None
    for entry in reversed(log_history):
        value = entry.get("eval_loss")
        if isinstance(value, (int, float)):
            f = float(value)
            if not math.isfinite(f):
                continue
            final_loss = f
            break
    final_ppl = perplexity(final_loss) if final_loss is not None else None
    return {"final_val_loss": final_loss, "final_val_perplexity": final_ppl}
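
# Typical use after training (a sketch; the perplexity value shown assumes
# `perplexity` is the plain exp of the loss):
#
#     summarize_eval_state([
#         {"loss": 1.2, "step": 100},
#         {"eval_loss": 0.93, "step": 100},
#         {"train_runtime": 42.0},
#     ])
#     # -> {"final_val_loss": 0.93, "final_val_perplexity": ~2.53}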