# Test module for dlm.eval.mode_split (val-loss split by data mode).
1 """Post-training val-loss split helper (audit-08 N9)."""
2
3 from __future__ import annotations
4
5 import logging
6 from typing import Any
7 from unittest.mock import MagicMock
8
9 import pytest
10
11 from dlm.eval.mode_split import _safe_eval_loss, compute_val_loss_by_mode
12
13
14 class _FakeDataset:
15 """Minimal Dataset-shaped stub: list iteration + `.select(indices)`."""
16
17 def __init__(self, rows: list[dict[str, Any]]) -> None:
18 self._rows = rows
19
20 def __len__(self) -> int:
21 return len(self._rows)
22
23 def __iter__(self): # type: ignore[no-untyped-def]
24 return iter(self._rows)
25
26 def select(self, indices: list[int]) -> _FakeDataset:
27 return _FakeDataset([self._rows[i] for i in indices])
28
29
30 def _trainer_with_fixed_losses(
31 cpt_loss: float | None = 0.5,
32 sft_loss: float | None = 0.3,
33 ) -> MagicMock:
34 """Mock: .evaluate() returns different eval_loss per subset size.
35
36 We use subset length as the key so we can identify which mode was
37 queried without inspecting row contents. The test asserts both
38 arms fire on the right subsets.
39 """
40 trainer = MagicMock()
41 call_log: list[int] = []
42
43 def _evaluate(*, eval_dataset: _FakeDataset) -> dict[str, float]:
44 n = len(eval_dataset)
45 call_log.append(n)
46 # Small trick: losses parameterized by call order so both
47 # get exercised; we don't care which is which as long as the
48 # helper returns them in (cpt, sft) order.
49 if cpt_loss is not None and len(call_log) == 1:
50 return {"eval_loss": cpt_loss}
51 if sft_loss is not None:
52 return {"eval_loss": sft_loss}
53 return {}
54
55 trainer.evaluate.side_effect = _evaluate
56 trainer._call_log = call_log # noqa: SLF001
57 return trainer
58
59
class TestEmptyOrMissing:
    """A missing, empty, or un-sized validation set short-circuits: no eval."""

    def test_none_val_ds_returns_both_none(self) -> None:
        mock_trainer = MagicMock()
        result = compute_val_loss_by_mode(mock_trainer, None)
        assert result == (None, None)
        mock_trainer.evaluate.assert_not_called()

    def test_empty_val_ds_returns_both_none(self) -> None:
        mock_trainer = MagicMock()
        empty = _FakeDataset([])
        assert compute_val_loss_by_mode(mock_trainer, empty) == (None, None)
        mock_trainer.evaluate.assert_not_called()

    def test_non_sized_dataset_returns_both_none(self) -> None:
        mock_trainer = MagicMock()
        # No __len__ at all — the helper must bail out rather than raise.
        unsized = _NonSizedDataset([{"text": "prose"}])
        assert compute_val_loss_by_mode(mock_trainer, unsized) == (None, None)
        mock_trainer.evaluate.assert_not_called()
78
79
class TestModeClassification:
    """Rows are routed to the CPT/SFT subsets by their schema markers."""

    def test_only_cpt_rows(self) -> None:
        mock_trainer = _trainer_with_fixed_losses(cpt_loss=0.7, sft_loss=None)
        dataset = _FakeDataset(
            [{"text": "prose a"}, {"text": "prose b"}, {"text": "prose c"}]
        )
        cpt_loss, sft_loss = compute_val_loss_by_mode(mock_trainer, dataset)
        assert cpt_loss == 0.7
        assert sft_loss is None
        # Exactly one evaluate() call, covering all three CPT rows.
        assert mock_trainer._call_log == [3]

    def test_only_sft_rows(self) -> None:
        mock_trainer = _trainer_with_fixed_losses(cpt_loss=None, sft_loss=0.4)
        chat_rows = [
            {"messages": [{"role": "user", "content": "hi"}]}
            for _ in range(2)
        ]
        cpt_loss, sft_loss = compute_val_loss_by_mode(
            mock_trainer, _FakeDataset(chat_rows)
        )
        assert cpt_loss is None
        assert sft_loss == 0.4

    def test_mixed_rows(self) -> None:
        mock_trainer = _trainer_with_fixed_losses(cpt_loss=0.9, sft_loss=0.5)
        rows: list[dict[str, Any]] = [
            {"text": "prose"},
            {"messages": []},
            {"text": "more prose"},
            {"messages": []},
        ]
        cpt_loss, sft_loss = compute_val_loss_by_mode(
            mock_trainer, _FakeDataset(rows)
        )
        assert cpt_loss == 0.9
        assert sft_loss == 0.5
        # Two evaluate() calls, one per mode, each over two rows.
        assert sorted(mock_trainer._call_log) == [2, 2]

    def test_preference_rows_skipped(self) -> None:
        """DPO-style triples belong to neither CPT nor SFT — they must not
        inflate either subset."""
        mock_trainer = _trainer_with_fixed_losses(cpt_loss=0.1, sft_loss=None)
        dataset = _FakeDataset(
            [
                {"text": "prose"},
                {"prompt": "q", "chosen": "c", "rejected": "r"},
                {"prompt": "q", "chosen": "c", "rejected": "r"},
            ]
        )
        cpt_loss, sft_loss = compute_val_loss_by_mode(mock_trainer, dataset)
        assert cpt_loss == 0.1
        assert sft_loss is None
        # Only the lone CPT row was evaluated; preference rows were dropped.
        assert mock_trainer._call_log == [1]
133
134
class TestEvalFailures:
    """Problems during evaluation degrade to None instead of crashing."""

    def test_evaluate_exception_yields_none(
        self,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """If evaluate() raises (e.g. library version skew), training must
        not crash — the affected mode's loss simply stays None."""
        caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
        mock_trainer = MagicMock()
        mock_trainer.evaluate.side_effect = RuntimeError("TRL drift")
        dataset = _FakeDataset([{"text": "a"}, {"messages": []}])
        cpt_loss, sft_loss = compute_val_loss_by_mode(mock_trainer, dataset)
        assert cpt_loss is None
        assert sft_loss is None
        # One warning per skipped mode.
        assert "val-loss split skipped cpt evaluation" in caplog.text
        assert "val-loss split skipped sft evaluation" in caplog.text

    def test_missing_eval_loss_key_yields_none(self) -> None:
        mock_trainer = MagicMock()
        mock_trainer.evaluate.return_value = {"other_metric": 1.0}
        cpt_loss, sft_loss = compute_val_loss_by_mode(
            mock_trainer, _FakeDataset([{"text": "a"}])
        )
        assert cpt_loss is None
        assert sft_loss is None

    def test_non_numeric_eval_loss_yields_none(self) -> None:
        mock_trainer = MagicMock()
        mock_trainer.evaluate.return_value = {"eval_loss": object()}
        cpt_loss, sft_loss = compute_val_loss_by_mode(
            mock_trainer, _FakeDataset([{"text": "a"}])
        )
        assert cpt_loss is None
        assert sft_loss is None

    def test_select_failure_yields_none(
        self,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
        mock_trainer = MagicMock()
        mock_trainer.evaluate.return_value = {"eval_loss": 0.0}
        # Iteration works, but there is no .select() to build subsets with.
        dataset = _NoSelectDataset([{"text": "a"}])
        cpt_loss, sft_loss = compute_val_loss_by_mode(mock_trainer, dataset)
        # Both None — the helper could not build either subset.
        assert cpt_loss is None
        assert sft_loss is None
        assert "val-loss split skipped cpt subset selection" in caplog.text
182
183
184 class _NoSelectDataset:
185 def __init__(self, rows: list[dict[str, Any]]) -> None:
186 self._rows = rows
187
188 def __len__(self) -> int:
189 return len(self._rows)
190
191 def __iter__(self): # type: ignore[no-untyped-def]
192 return iter(self._rows)
193
194
195 class _NonSizedDataset:
196 def __init__(self, rows: list[dict[str, Any]]) -> None:
197 self._rows = rows
198
199 def __iter__(self): # type: ignore[no-untyped-def]
200 return iter(self._rows)
201
202
def test_safe_eval_loss_value_error_yields_none(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """``_safe_eval_loss`` swallows ValueError from evaluate() and warns."""
    caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
    mock_trainer = MagicMock()
    mock_trainer.evaluate.side_effect = ValueError("bad eval")
    dataset = _FakeDataset([{"text": "a"}])

    result = _safe_eval_loss(mock_trainer, dataset, [0], mode="cpt")

    assert result is None
    assert "val-loss split skipped cpt evaluation" in caplog.text