"""Early stopping — thin wrapper over `transformers.EarlyStoppingCallback`.

The HF callback itself does the work (monitors `metric_for_best_model`,
increments a patience counter, sets `control.should_training_stop` when
the counter exceeds `early_stopping_patience`). We wrap it for two
reasons:

1. **Config validation.** Patience must be ≥1; threshold must be
   finite and ≥0. HF accepts nonsense defaults and silently degrades.
2. **Inspection.** After `trainer.train()` returns, we need to know
   whether the run actually early-stopped (vs. completing normally)
   so the summary can record it. HF exposes this via
   `trainer.state.best_metric` + `global_step`; we bundle the check
   in `was_early_stopped()` so downstream callers don't reach into
   trainer internals.
"""
|
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    pass
|
@dataclass(frozen=True)
class EarlyStopConfig:
    """Knobs the trainer threads into `SFTConfig` + the callback.

    `patience`: eval rounds without improvement before stopping.
    `threshold`: minimum `metric_for_best_model` delta that counts as
    improvement. 0.0 means any improvement resets the patience counter.
    `metric`: HF metric name (`"eval_loss"` by default; the
    `compute_metrics` hook also emits `"eval_perplexity"`).
    `greater_is_better`: direction of improvement for `metric`
    (False for loss-like metrics, True for accuracy-like ones).

    Raises `ValueError` from `__post_init__` on out-of-range values.
    """

    patience: int = 3
    threshold: float = 0.0
    metric: str = "eval_loss"
    greater_is_better: bool = False

    def __post_init__(self) -> None:
        if self.patience < 1:
            raise ValueError(f"patience must be >= 1, got {self.patience}")
        # The module contract is "finite and >= 0". The bare `< 0.0` check
        # let NaN (every comparison is False) and +inf slip through to HF,
        # which is exactly the silent degradation this wrapper exists to
        # prevent — so reject non-finite values explicitly.
        if not math.isfinite(self.threshold) or self.threshold < 0.0:
            raise ValueError(
                f"threshold must be finite and >= 0.0, got {self.threshold}"
            )
        if not self.metric:
            raise ValueError("metric must be a non-empty string")
|
def build_callback(cfg: EarlyStopConfig) -> Any:
    """Construct the HF `EarlyStoppingCallback` described by *cfg*.

    The import lives inside the function so this module stays importable
    without `transformers` installed.
    """
    from transformers import EarlyStoppingCallback

    callback = EarlyStoppingCallback(
        early_stopping_patience=cfg.patience,
        early_stopping_threshold=cfg.threshold,
    )
    return callback
|
| 62 |
def was_early_stopped( |
| 63 |
*, max_steps_ran: int, configured_max_steps: int | None, num_epochs_done: float |
| 64 |
) -> bool: |
| 65 |
"""Best-effort detection of early-stop vs. normal completion. |
| 66 |
|
| 67 |
HF sets `trainer.state.global_step == max_steps` on natural end |
| 68 |
(when `max_steps > 0`) or completes the full epoch schedule. If the |
| 69 |
trainer exited before either of those, early-stop is the most |
| 70 |
likely reason. |
| 71 |
|
| 72 |
This is intentionally imprecise — `trainer.state` doesn't expose |
| 73 |
an explicit "early stopped" flag — but the heuristic is right in |
| 74 |
the normal case and harmless in the ambiguous case (the summary |
| 75 |
just records `early_stopped=False`, which is conservative). |
| 76 |
""" |
| 77 |
if configured_max_steps is not None and configured_max_steps > 0: |
| 78 |
return max_steps_ran < configured_max_steps |
| 79 |
# No max_steps cap → we're running to num_epochs. A non-integer |
| 80 |
# `num_epochs_done` means we exited mid-epoch, which is the |
| 81 |
# early-stop signal. |
| 82 |
return not float(num_epochs_done).is_integer() |