"""Val loss → perplexity adapter for HF `compute_metrics`.

SFTTrainer calls `compute_metrics(eval_pred)` on every evaluation pass
when we pass a callable as the `compute_metrics` argument to the
`SFTTrainer` / `Trainer` constructor. The callable receives an
`EvalPrediction` object whose `.predictions` and `.label_ids` are
accumulated over all eval batches.

For language modeling we don't actually need the predictions — HF has
already computed the eval loss inside its eval loop by the time
`compute_metrics` fires, and records it in `trainer.state.log_history`
under `"eval_loss"` once the eval snapshot is logged. The
`compute_metrics` hook exists so we can add derived metrics
(perplexity) that HF then logs alongside.

This module exports two callables: `eval_metrics_from_eval_pred`, which
the trainer wires in as `compute_metrics`, and `summarize_eval_state`,
which pulls the final `eval_loss` out of `trainer.state.log_history`
after training and returns the PPL summary dict.
"""

from __future__ import annotations

import math
from typing import Any

from dlm.eval.perplexity import perplexity
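
# How the trainer side might wire the hook in: a sketch, not executed here.
# `SFTTrainer` / `SFTConfig` come from TRL; the model and dataset names
# below are placeholders, and the exact constructor kwargs are assumptions:
#
#     from trl import SFTConfig, SFTTrainer
#
#     trainer = SFTTrainer(
#         model=model,
#         args=SFTConfig(output_dir="out", eval_strategy="steps"),
#         train_dataset=train_ds,
#         eval_dataset=val_ds,
#         compute_metrics=eval_metrics_from_eval_pred,
#     )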


def eval_metrics_from_eval_pred(eval_pred: Any) -> dict[str, float]:
    """Compute-metrics hook compatible with `Trainer.compute_metrics`.

    `eval_pred` is expected to be an `EvalPrediction`-like object; we
    only inspect `.metrics` (set by HF's internal eval loop after loss
    has been computed). If `metrics` isn't present we return an empty
    dict — the HF side will still log `eval_loss` itself.
    """
    metrics = getattr(eval_pred, "metrics", None) or {}
    loss = metrics.get("eval_loss")
    if not isinstance(loss, (int, float)):
        return {}
    return {"perplexity": perplexity(float(loss))}


def summarize_eval_state(log_history: list[dict[str, Any]]) -> dict[str, float | None]:
    """Extract `final_val_loss` + `final_val_perplexity` from trainer history.

    `log_history` is `trainer.state.log_history` — a list of dicts, one
    per logged metric snapshot. The last entry containing a finite
    `eval_loss` is the authoritative final eval result. Non-finite
    (NaN/inf) values are dropped silently here;
    `dlm.train.integrity.assert_eval_finite` raises separately if the
    trainer produced a NaN eval, so the run is marked FAILED at the
    orchestration layer.
    """
    final_loss: float | None = None
    for entry in reversed(log_history):
        value = entry.get("eval_loss")
        if isinstance(value, (int, float)):
            f = float(value)
            if not math.isfinite(f):
                # Non-finite snapshot: keep scanning for an earlier finite one.
                continue
            final_loss = f
            break
    final_ppl = perplexity(final_loss) if final_loss is not None else None
    return {"final_val_loss": final_loss, "final_val_perplexity": final_ppl}