tenseleyflow/documentlanguagemodel / 6e42964

Browse files

test(train,io): unit tests for trainer helpers + _encode_crockford error path

Rather than pragma away the unit-testable helpers in trainer.py,
exercise them directly: _maybe_float, _utc_naive, _sample_replay_rows
(cold + warm + k-floor + determinism), _next_run_id (missing manifest /
empty runs / prior runs), _append_training_run (inside-store relative
vs outside-store absolute fallback), _snapshot_training_state (scaler
present vs absent). Also cover _encode_crockford's 16-byte validator
which had no direct test.
Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
6e4296473a526e2012ead895a2274c893a3f6626
Parents
072257a
Tree
6b1a115

2 changed files

StatusFile+-
M tests/unit/test_io_ulid.py 13 0
A tests/unit/train/test_trainer_helpers.py 310 0
tests/unit/test_io_ulid.pymodified
@@ -50,6 +50,19 @@ class TestSortability:
5050
         assert a[:10] <= b[:10]
5151
 
5252
 
53
+class TestInternals:
54
+    def test_wrong_payload_size_raises(self) -> None:
55
+        """`_encode_crockford` rejects anything other than exactly 16 bytes."""
56
+        import pytest
57
+
58
+        from dlm.io.ulid import _encode_crockford
59
+
60
+        with pytest.raises(ValueError, match="16 bytes"):
61
+            _encode_crockford(b"\x00" * 15)
62
+        with pytest.raises(ValueError, match="16 bytes"):
63
+            _encode_crockford(b"\x00" * 17)
64
+
65
+
5366
 class TestValidatorCompat:
5467
     def test_accepted_by_frontmatter_validator(self) -> None:
5568
         """Round-trip: mint → validate against the schema's ULID regex."""
tests/unit/train/test_trainer_helpers.pyadded
@@ -0,0 +1,310 @@
1
+"""Unit tests for `trainer.py` private helpers (Sprint 13 coverage pass).
2
+
3
+These helpers were under-covered because the public `run()` orchestrator
4
+requires a real HF model, which only the slow integration test can
5
+provide. The helpers themselves are pure Python / pydantic and worth
6
+testing directly.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
+from pathlib import Path
12
+
13
+from dlm.train.trainer import (
14
+    _append_training_run,
15
+    _maybe_float,
16
+    _next_run_id,
17
+    _sample_replay_rows,
18
+    _utc_naive,
19
+)
20
+
21
+# --- _maybe_float -----------------------------------------------------------
22
+
23
+
24
+class TestMaybeFloat:
25
+    def test_none_returns_none(self) -> None:
26
+        assert _maybe_float(None) is None
27
+
28
+    def test_numeric_returns_float(self) -> None:
29
+        assert _maybe_float(3) == 3.0
30
+        assert _maybe_float(2.5) == 2.5
31
+
32
+    def test_string_numeric_parses(self) -> None:
33
+        assert _maybe_float("1.25") == 1.25
34
+
35
+    def test_bad_string_returns_none(self) -> None:
36
+        assert _maybe_float("not a number") is None
37
+
38
+    def test_invalid_type_returns_none(self) -> None:
39
+        assert _maybe_float(object()) is None
40
+
41
+
42
+# --- _utc_naive -------------------------------------------------------------
43
+
44
+
45
+class TestUtcNaive:
46
+    def test_is_naive(self) -> None:
47
+        ts = _utc_naive()
48
+        assert ts.tzinfo is None
49
+
50
+    def test_microseconds_zeroed(self) -> None:
51
+        ts = _utc_naive()
52
+        assert ts.microsecond == 0
53
+
54
+
55
+# --- _sample_replay_rows ----------------------------------------------------
56
+
57
+
58
+class _FakeChangeSet:
59
+    def __init__(self, new_count: int) -> None:
60
+        self.new = [object() for _ in range(new_count)]
61
+
62
+
63
+class _EmptyReplay:
64
+    def load(self) -> list[object]:
65
+        return []
66
+
67
+    def sample_rows(self, *, k: int, now: object, rng: object) -> list[dict[str, object]]:
68
+        raise AssertionError("should not sample when empty")
69
+
70
+
71
+class _WarmReplay:
72
+    def __init__(self, entries: int = 10) -> None:
73
+        self._entries = [f"entry-{i}" for i in range(entries)]
74
+        self.last_k: int | None = None
75
+
76
+    def load(self) -> list[str]:
77
+        return list(self._entries)
78
+
79
+    def sample_rows(self, *, k: int, now: object, rng: object) -> list[dict[str, object]]:
80
+        self.last_k = k
81
+        return [{"row": i} for i in range(min(k, len(self._entries)))]
82
+
83
+
84
+class TestSampleReplayRows:
85
+    def test_cold_corpus_returns_empty(self) -> None:
86
+        replay = _EmptyReplay()
87
+        out = _sample_replay_rows(
88
+            replay,  # type: ignore[arg-type]
89
+            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
90
+            seed=42,
91
+            adapter_version=1,
92
+        )
93
+        assert out == []
94
+
95
+    def test_warm_corpus_samples_k_equals_2x_new_floor_32(self) -> None:
96
+        replay = _WarmReplay(entries=200)
97
+        out = _sample_replay_rows(
98
+            replay,  # type: ignore[arg-type]
99
+            change_set=_FakeChangeSet(100),  # type: ignore[arg-type]
100
+            seed=42,
101
+            adapter_version=1,
102
+        )
103
+        # k = max(32, 2 * 100) = 200; replay has 200 entries so all returned.
104
+        assert replay.last_k == 200
105
+        assert len(out) == 200
106
+
107
+    def test_small_change_set_uses_min_k_of_32(self) -> None:
108
+        replay = _WarmReplay(entries=100)
109
+        _sample_replay_rows(
110
+            replay,  # type: ignore[arg-type]
111
+            change_set=_FakeChangeSet(0),  # |new| = 0 → k = max(32, 0) = 32
112
+            seed=0,
113
+            adapter_version=1,
114
+        )
115
+        assert replay.last_k == 32
116
+
117
+    def test_deterministic_across_calls(self) -> None:
118
+        """Same (seed, adapter_version) → same RNG state per call."""
119
+        replay1 = _WarmReplay(entries=50)
120
+        replay2 = _WarmReplay(entries=50)
121
+
122
+        # Both use seed=7, adapter_version=3. The RNG seeds to 10, so
123
+        # both sample_rows calls receive an equal-state Random instance.
124
+        _sample_replay_rows(
125
+            replay1,  # type: ignore[arg-type]
126
+            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
127
+            seed=7,
128
+            adapter_version=3,
129
+        )
130
+        _sample_replay_rows(
131
+            replay2,  # type: ignore[arg-type]
132
+            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
133
+            seed=7,
134
+            adapter_version=3,
135
+        )
136
+        assert replay1.last_k == replay2.last_k
137
+
138
+
139
+# --- _next_run_id + _append_training_run -----------------------------------
140
+
141
+
142
+def _bootstrap_store(tmp_path: Path) -> object:
143
+    """Make a minimal StorePath with a valid manifest for helper tests."""
144
+    from dlm.store.manifest import Manifest, save_manifest
145
+    from dlm.store.paths import for_dlm
146
+
147
+    home = tmp_path / "dlm-home"
148
+    store = for_dlm("01HZ4X7TGZM3J1A2B3C4D5E6F7", home=home)
149
+    store.ensure_layout()
150
+    save_manifest(store.manifest, Manifest(dlm_id=store.root.name, base_model="smollm2-135m"))
151
+    return store
152
+
153
+
154
+class TestNextRunId:
155
+    def test_missing_manifest_returns_1(self, tmp_path: Path) -> None:
156
+        """Edge case: manifest not yet written → fresh run."""
157
+        from dlm.store.paths import for_dlm
158
+
159
+        home = tmp_path / "dlm-home"
160
+        store = for_dlm("01HZ4X7TGZM3J1A2B3C4D5E6F7", home=home)
161
+        # Don't ensure_layout / save_manifest — leave manifest missing.
162
+        assert _next_run_id(store) == 1
163
+
164
+    def test_empty_training_runs_returns_1(self, tmp_path: Path) -> None:
165
+        store = _bootstrap_store(tmp_path)
166
+        assert _next_run_id(store) == 1  # type: ignore[arg-type]
167
+
168
+    def test_with_prior_runs_returns_max_plus_one(self, tmp_path: Path) -> None:
169
+        from dlm.store.manifest import TrainingRunSummary, load_manifest, save_manifest
170
+
171
+        store = _bootstrap_store(tmp_path)
172
+        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
173
+        updated = manifest.model_copy(
174
+            update={
175
+                "training_runs": [
176
+                    TrainingRunSummary(
177
+                        run_id=1, started_at=_utc_naive(), adapter_version=1, seed=0
178
+                    ),
179
+                    TrainingRunSummary(
180
+                        run_id=5, started_at=_utc_naive(), adapter_version=1, seed=0
181
+                    ),
182
+                ],
183
+            }
184
+        )
185
+        save_manifest(store.manifest, updated)  # type: ignore[attr-defined]
186
+        assert _next_run_id(store) == 6  # type: ignore[arg-type]
187
+
188
+
189
+class TestAppendTrainingRun:
190
+    def test_summary_path_outside_store_recorded_absolute(self, tmp_path: Path) -> None:
191
+        """The relative_to() ValueError branch: fallback to absolute path."""
192
+        from dlm.store.manifest import load_manifest
193
+
194
+        store = _bootstrap_store(tmp_path)
195
+        # A path that can't be made relative to store.root.
196
+        outside = tmp_path / "outside" / "summary.json"
197
+        outside.parent.mkdir(parents=True, exist_ok=True)
198
+        outside.touch()
199
+
200
+        _append_training_run(
201
+            store=store,  # type: ignore[arg-type]
202
+            run_id=1,
203
+            adapter_version=1,
204
+            seed=0,
205
+            steps=10,
206
+            final_train_loss=0.5,
207
+            final_val_loss=None,
208
+            base_model_revision="deadbeef",
209
+            versions={"torch": "2.4.0"},
210
+            current_sections=[],
211
+            summary_path=outside,
212
+        )
213
+
214
+        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
215
+        assert len(manifest.training_runs) == 1
216
+        recorded = manifest.training_runs[0].summary_path
217
+        # Outside-store path is absolute (matches the input).
218
+        assert recorded == str(outside)
219
+
220
+    def test_summary_path_under_store_recorded_relative(self, tmp_path: Path) -> None:
221
+        from dlm.store.manifest import load_manifest
222
+
223
+        store = _bootstrap_store(tmp_path)
224
+        # A path inside the store.
225
+        store.logs.mkdir(parents=True, exist_ok=True)  # type: ignore[attr-defined]
226
+        inside = store.logs / "summary.json"  # type: ignore[attr-defined]
227
+        inside.touch()
228
+
229
+        _append_training_run(
230
+            store=store,  # type: ignore[arg-type]
231
+            run_id=1,
232
+            adapter_version=1,
233
+            seed=0,
234
+            steps=10,
235
+            final_train_loss=0.5,
236
+            final_val_loss=None,
237
+            base_model_revision="deadbeef",
238
+            versions={"torch": "2.4.0"},
239
+            current_sections=[],
240
+            summary_path=inside,
241
+        )
242
+
243
+        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
244
+        assert len(manifest.training_runs) == 1
245
+        recorded = manifest.training_runs[0].summary_path
246
+        # Relative to store root, not absolute.
247
+        assert recorded is not None
248
+        assert not Path(recorded).is_absolute()
249
+
250
+
251
+# --- _snapshot_training_state (scaler path) ---------------------------------
252
+
253
+
254
+class _FakeOptimizer:
255
+    def state_dict(self) -> dict[str, str]:
256
+        return {"opt": "state"}
257
+
258
+
259
+class _FakeScaler:
260
+    def state_dict(self) -> dict[str, str]:
261
+        return {"scaler": "state"}
262
+
263
+
264
+class _FakeState:
265
+    global_step = 42
266
+    epoch = 1.5
267
+    best_metric = None
268
+
269
+
270
+class _FakeSft:
271
+    def __init__(self, with_scaler: bool = False) -> None:
272
+        self.optimizer = _FakeOptimizer()
273
+        self.lr_scheduler = None
274
+        self.state = _FakeState()
275
+        self.scaler = _FakeScaler() if with_scaler else None
276
+
277
+
278
+def _smollm_spec() -> object:
279
+    from dlm.base_models import BASE_MODELS
280
+
281
+    return BASE_MODELS["smollm2-135m"]
282
+
283
+
284
+class TestSnapshotTrainingState:
285
+    def test_captures_scaler_when_present(self) -> None:
286
+        from dlm.train.trainer import _snapshot_training_state
287
+
288
+        sft = _FakeSft(with_scaler=True)
289
+        state = _snapshot_training_state(
290
+            sft,
291
+            spec=_smollm_spec(),  # type: ignore[arg-type]
292
+            versions={"torch": "2.4.0"},
293
+            use_qlora=False,
294
+        )
295
+        assert state["scaler_state_dict"] == {"scaler": "state"}
296
+        assert state["global_step"] == 42
297
+        assert state["use_qlora"] is False
298
+
299
+    def test_no_scaler_leaves_none(self) -> None:
300
+        from dlm.train.trainer import _snapshot_training_state
301
+
302
+        sft = _FakeSft(with_scaler=False)
303
+        state = _snapshot_training_state(
304
+            sft,
305
+            spec=_smollm_spec(),  # type: ignore[arg-type]
306
+            versions={"torch": "2.4.0"},
307
+            use_qlora=True,
308
+        )
309
+        assert state["scaler_state_dict"] is None
310
+        assert state["use_qlora"] is True