Python · 3182 bytes Raw Blame History
1 """Early stopping — thin wrapper over `transformers.EarlyStoppingCallback`.
2
3 The HF callback itself does the work (monitors `metric_for_best_model`,
4 increments a patience counter, sets `control.should_training_stop` when
5 the counter exceeds `early_stopping_patience`). We wrap it for two
6 reasons:
7
8 1. **Config validation.** Patience must be ≥1; threshold must be
9 finite and ≥0. HF accepts nonsense defaults and silently degrades.
10 2. **Inspection.** After `trainer.train()` returns, we need to know
11 whether the run actually early-stopped (vs. completing normally)
12 so the summary can record it. HF exposes this via
13 `trainer.state.best_metric` + `global_step`; we bundle the check
14 in `was_early_stopped()` so downstream callers don't reach into
15 trainer internals.
16 """
17
18 from __future__ import annotations
19
20 from dataclasses import dataclass
21 from typing import TYPE_CHECKING, Any
22
23 if TYPE_CHECKING:
24 pass
25
26
27 @dataclass(frozen=True)
28 class EarlyStopConfig:
29 """Knobs the trainer threads into `SFTConfig` + the callback.
30
31 `patience`: eval rounds without improvement before stopping.
32 `threshold`: minimum `metric_for_best_model` delta that counts as
33 improvement. 0.0 means any improvement resets the patience counter.
34 `metric`: HF metric name (`"eval_loss"` by default; the
35 `compute_metrics` hook also emits `"eval_perplexity"`).
36 """
37
38 patience: int = 3
39 threshold: float = 0.0
40 metric: str = "eval_loss"
41 greater_is_better: bool = False
42
43 def __post_init__(self) -> None:
44 if self.patience < 1:
45 raise ValueError(f"patience must be >= 1, got {self.patience}")
46 if self.threshold < 0.0:
47 raise ValueError(f"threshold must be >= 0.0, got {self.threshold}")
48 if not self.metric:
49 raise ValueError("metric must be a non-empty string")
50
51
52 def build_callback(cfg: EarlyStopConfig) -> Any:
53 """Instantiate an HF `EarlyStoppingCallback` from this config."""
54 from transformers import EarlyStoppingCallback
55
56 return EarlyStoppingCallback(
57 early_stopping_patience=cfg.patience,
58 early_stopping_threshold=cfg.threshold,
59 )
60
61
62 def was_early_stopped(
63 *, max_steps_ran: int, configured_max_steps: int | None, num_epochs_done: float
64 ) -> bool:
65 """Best-effort detection of early-stop vs. normal completion.
66
67 HF sets `trainer.state.global_step == max_steps` on natural end
68 (when `max_steps > 0`) or completes the full epoch schedule. If the
69 trainer exited before either of those, early-stop is the most
70 likely reason.
71
72 This is intentionally imprecise — `trainer.state` doesn't expose
73 an explicit "early stopped" flag — but the heuristic is right in
74 the normal case and harmless in the ambiguous case (the summary
75 just records `early_stopped=False`, which is conservative).
76 """
77 if configured_max_steps is not None and configured_max_steps > 0:
78 return max_steps_ran < configured_max_steps
79 # No max_steps cap → we're running to num_epochs. A non-integer
80 # `num_epochs_done` means we exited mid-epoch, which is the
81 # early-stop signal.
82 return not float(num_epochs_done).is_integer()