Python · 19950 bytes Raw Blame History
1 """Unit tests for `trainer.py` private helpers (Sprint 13 coverage pass).
2
3 These helpers were under-covered because the public `run()` orchestrator
4 requires a real HF model, which only the slow integration test can
5 provide. The helpers themselves are pure Python / pydantic and worth
6 testing directly.
7 """
8
9 from __future__ import annotations
10
11 import logging
12 from pathlib import Path
13 from types import SimpleNamespace
14 from typing import cast
15 from unittest.mock import MagicMock
16
17 import pytest
18
19 from dlm.base_models import BASE_MODELS
20 from dlm.directives import ExpandResult, SourceProvenance
21 from dlm.directives.discovery import DiscoveredConfig
22 from dlm.directives.schema import DlmTrainingConfig
23 from dlm.doc.parser import ParsedDlm
24 from dlm.doc.schema import DlmFrontmatter, SourceDirective, TrainingConfig
25 from dlm.doc.sections import Section, SectionType
26 from dlm.lock import LockDecision, LockSchemaError, Severity
27 from dlm.replay import ChangeSet
28 from dlm.train.trainer import (
29 _append_change_set_to_replay,
30 _append_training_run,
31 _attach_dlm_trainer_callback,
32 _build_candidate_lock,
33 _compute_weight_distribution,
34 _expand_directives,
35 _maybe_float,
36 _maybe_record_tokenization,
37 _next_run_id,
38 _sample_replay_rows,
39 _utc_naive,
40 _validate_or_abort_lock,
41 )
42
43 # --- _maybe_float -----------------------------------------------------------
44
45
class TestMaybeFloat:
    """`_maybe_float` coerces numbers and numeric strings to float, else yields None."""

    def test_none_returns_none(self) -> None:
        result = _maybe_float(None)
        assert result is None

    def test_numeric_returns_float(self) -> None:
        for raw, expected in ((3, 3.0), (2.5, 2.5)):
            assert _maybe_float(raw) == expected

    def test_string_numeric_parses(self) -> None:
        parsed = _maybe_float("1.25")
        assert parsed == 1.25

    def test_bad_string_returns_none(self) -> None:
        parsed = _maybe_float("not a number")
        assert parsed is None

    def test_invalid_type_returns_none(self) -> None:
        opaque = object()
        assert _maybe_float(opaque) is None
62
63
64 # --- _utc_naive -------------------------------------------------------------
65
66
class TestUtcNaive:
    """`_utc_naive` yields a timestamp without tzinfo, truncated to whole seconds."""

    def test_is_naive(self) -> None:
        assert _utc_naive().tzinfo is None

    def test_microseconds_zeroed(self) -> None:
        stamp = _utc_naive()
        assert not stamp.microsecond
75
76
77 # --- _sample_replay_rows ----------------------------------------------------
78
79
def _fake_change_set(new_count: int) -> ChangeSet:
    """Build a ChangeSet whose ``new`` list holds ``new_count`` prose sections ("row 0", ...)."""
    fresh_sections: list[Section] = []
    for idx in range(new_count):
        fresh_sections.append(Section(type=SectionType.PROSE, content=f"row {idx}"))
    return ChangeSet(new=fresh_sections)
84
85
86 class _EmptyReplay:
87 def load(self) -> list[object]:
88 return []
89
90 def sample_rows(self, *, k: int, now: object, rng: object) -> list[dict[str, object]]:
91 raise AssertionError("should not sample when empty")
92
93
94 class _WarmReplay:
95 def __init__(self, entries: int = 10) -> None:
96 self._entries = [f"entry-{i}" for i in range(entries)]
97 self.last_k: int | None = None
98
99 def load(self) -> list[str]:
100 return list(self._entries)
101
102 def sample_rows(self, *, k: int, now: object, rng: object) -> list[dict[str, object]]:
103 self.last_k = k
104 return [{"row": i} for i in range(min(k, len(self._entries)))]
105
106
class TestSampleReplayRows:
    """Sizing (``k = max(32, 2 * |new|)``) and determinism of `_sample_replay_rows`."""

    def test_cold_corpus_returns_empty(self) -> None:
        replay = _EmptyReplay()
        out = _sample_replay_rows(
            replay,  # type: ignore[arg-type]
            change_set=_fake_change_set(5),
            seed=42,
            adapter_version=1,
        )
        assert out == []

    def test_warm_corpus_samples_k_equals_2x_new_floor_32(self) -> None:
        replay = _WarmReplay(entries=200)
        out = _sample_replay_rows(
            replay,  # type: ignore[arg-type]
            change_set=_fake_change_set(100),
            seed=42,
            adapter_version=1,
        )
        # k = max(32, 2 * 100) = 200; replay has 200 entries so all returned.
        assert replay.last_k == 200
        assert len(out) == 200

    def test_small_change_set_uses_min_k_of_32(self) -> None:
        replay = _WarmReplay(entries=100)
        _sample_replay_rows(
            replay,  # type: ignore[arg-type]
            change_set=_fake_change_set(0),  # |new| = 0 → k = max(32, 0) = 32
            seed=0,
            adapter_version=1,
        )
        assert replay.last_k == 32

    def test_deterministic_across_calls(self) -> None:
        """Same (seed, adapter_version) → same RNG state per call.

        Comparing only ``last_k`` would pass even with a clock-seeded RNG, because
        ``k`` depends solely on the change-set size. Each fake therefore also
        records one draw from the RNG it is handed; equal draws from two fresh
        calls mean the generators were seeded identically.
        """

        class _RecordingReplay(_WarmReplay):
            """`_WarmReplay` that records one draw from every RNG it receives."""

            def __init__(self, entries: int = 10) -> None:
                super().__init__(entries)
                self.draws: list[float] = []

            def sample_rows(
                self, *, k: int, now: object, rng: object
            ) -> list[dict[str, object]]:
                # Assumes ``rng`` exposes ``random()`` (a ``random.Random`` per the
                # helper's seeding contract); a single draw fingerprints its state.
                self.draws.append(rng.random())  # type: ignore[attr-defined]
                return super().sample_rows(k=k, now=now, rng=rng)

        replay1 = _RecordingReplay(entries=50)
        replay2 = _RecordingReplay(entries=50)

        # Both use seed=7, adapter_version=3, so both sample_rows calls should
        # receive an equal-state Random instance.
        for replay in (replay1, replay2):
            _sample_replay_rows(
                replay,  # type: ignore[arg-type]
                change_set=_fake_change_set(5),
                seed=7,
                adapter_version=3,
            )

        assert replay1.last_k == replay2.last_k
        assert replay1.draws, "sample_rows was never called"
        assert replay1.draws == replay2.draws
160
161
162 # --- _next_run_id + _append_training_run -----------------------------------
163
164
def _bootstrap_store(tmp_path: Path) -> object:
    """Create a laid-out store under ``tmp_path / "dlm-home"`` with a minimal manifest.

    The manifest names the store's own directory as ``dlm_id`` and pins
    ``smollm2-135m`` as the base model.
    """
    from dlm.store.manifest import Manifest, save_manifest
    from dlm.store.paths import for_dlm

    store = for_dlm("01HZ4X7TGZM3J1A2B3C4D5E6F7", home=tmp_path / "dlm-home")
    store.ensure_layout()
    minimal_manifest = Manifest(dlm_id=store.root.name, base_model="smollm2-135m")
    save_manifest(store.manifest, minimal_manifest)
    return store
175
176
# Distinguishes "caller did not pass source_path" from an explicit ``None``.
_SOURCE_PATH_SENTINEL = object()


def _parsed(
    tmp_path: Path,
    *,
    source_path: object = _SOURCE_PATH_SENTINEL,
    sections: tuple[Section, ...] | None = None,
    sources: tuple[SourceDirective, ...] | None = None,
) -> ParsedDlm:
    """Build a minimal ParsedDlm for helper tests.

    Args:
        tmp_path: Directory where a placeholder ``doc.dlm`` is written when
            ``source_path`` is omitted.
        source_path: Omitted → ``tmp_path / "doc.dlm"`` (created on disk);
            ``None`` → the document has no source path; a ``Path`` → used as-is.
        sections: Sections to attach. ``None`` → a single prose section ``"x"``.
            An explicitly passed empty tuple is kept, not replaced by the default.
        sources: ``training.sources`` directives placed in the frontmatter.

    Returns:
        A ParsedDlm with fixed dlm_id, ``smollm2-135m`` base model and seed 42.

    Raises:
        TypeError: ``source_path`` was passed but is neither ``None`` nor a ``Path``.
    """
    resolved_source_path: Path | None
    if source_path is _SOURCE_PATH_SENTINEL:
        resolved_source_path = tmp_path / "doc.dlm"
        resolved_source_path.write_text("placeholder .dlm body\n", encoding="utf-8")
    elif source_path is None or isinstance(source_path, Path):
        resolved_source_path = source_path
    else:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise TypeError(
            f"source_path must be a Path or None, got {type(source_path).__name__}"
        )

    # ``is None`` rather than ``or``: an empty tuple is a deliberate choice.
    if sections is None:
        sections = (Section(type=SectionType.PROSE, content="x"),)

    return ParsedDlm(
        frontmatter=DlmFrontmatter(
            dlm_id="01HZ4X7TGZM3J1A2B3C4D5E6F7",
            base_model="smollm2-135m",
            training=TrainingConfig(seed=42, sources=sources),
        ),
        sections=sections,
        source_path=resolved_source_path,
    )
203
204
class TestNextRunId:
    """`_next_run_id` returns one past the highest recorded run id (1 when none)."""

    def test_missing_manifest_returns_1(self, tmp_path: Path) -> None:
        """Edge case: manifest not yet written → fresh run."""
        from dlm.store.paths import for_dlm

        # No ensure_layout / save_manifest: the manifest file never exists.
        bare_store = for_dlm("01HZ4X7TGZM3J1A2B3C4D5E6F7", home=tmp_path / "dlm-home")
        assert _next_run_id(bare_store) == 1

    def test_empty_training_runs_returns_1(self, tmp_path: Path) -> None:
        assert _next_run_id(_bootstrap_store(tmp_path)) == 1  # type: ignore[arg-type]

    def test_with_prior_runs_returns_max_plus_one(self, tmp_path: Path) -> None:
        from dlm.store.manifest import TrainingRunSummary, load_manifest, save_manifest

        store = _bootstrap_store(tmp_path)
        # Non-contiguous ids (1, 5): the next id follows the max, not the count.
        prior_runs = [
            TrainingRunSummary(run_id=rid, started_at=_utc_naive(), adapter_version=1, seed=0)
            for rid in (1, 5)
        ]
        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
        save_manifest(
            store.manifest,  # type: ignore[attr-defined]
            manifest.model_copy(update={"training_runs": prior_runs}),
        )
        assert _next_run_id(store) == 6  # type: ignore[arg-type]
238
239
class TestAppendTrainingRun:
    """How `_append_training_run` records ``summary_path`` in the manifest."""

    @staticmethod
    def _append(store: object, summary_path: Path) -> None:
        """Append run 1 with fixed metadata; only ``summary_path`` varies per test."""
        _append_training_run(
            store=store,  # type: ignore[arg-type]
            run_id=1,
            adapter_version=1,
            seed=0,
            steps=10,
            final_train_loss=0.5,
            final_val_loss=None,
            base_model_revision="deadbeef",
            versions={"torch": "2.4.0"},
            current_sections=[],
            summary_path=summary_path,
        )

    def test_summary_path_outside_store_recorded_absolute(self, tmp_path: Path) -> None:
        """The relative_to() ValueError branch: fallback to absolute path."""
        from dlm.store.manifest import load_manifest

        store = _bootstrap_store(tmp_path)
        # A path that can't be made relative to store.root.
        outside = tmp_path / "outside" / "summary.json"
        outside.parent.mkdir(parents=True, exist_ok=True)
        outside.touch()

        self._append(store, outside)

        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
        assert len(manifest.training_runs) == 1
        # Outside-store path is absolute (matches the input).
        assert manifest.training_runs[0].summary_path == str(outside)

    def test_summary_path_under_store_recorded_relative(self, tmp_path: Path) -> None:
        """A path inside the store is recorded relative to ``store.root``."""
        from dlm.store.manifest import load_manifest

        store = _bootstrap_store(tmp_path)
        store.logs.mkdir(parents=True, exist_ok=True)  # type: ignore[attr-defined]
        inside = store.logs / "summary.json"  # type: ignore[attr-defined]
        inside.touch()

        self._append(store, inside)

        manifest = load_manifest(store.manifest)  # type: ignore[attr-defined]
        assert len(manifest.training_runs) == 1
        recorded = manifest.training_runs[0].summary_path
        assert recorded is not None
        assert not Path(recorded).is_absolute()
        # Pin the exact value: "some relative path" would also accept a wrong one.
        assert Path(recorded) == inside.relative_to(store.root)  # type: ignore[attr-defined]
300
301
302 # --- _snapshot_training_state (scaler path) ---------------------------------
303
304
305 class _FakeOptimizer:
306 def state_dict(self) -> dict[str, str]:
307 return {"opt": "state"}
308
309
310 class _FakeScaler:
311 def state_dict(self) -> dict[str, str]:
312 return {"scaler": "state"}
313
314
315 class _FakeState:
316 global_step = 42
317 epoch = 1.5
318 best_metric = None
319
320
class _FakeSft:
    """Trainer stand-in used by the `_snapshot_training_state` tests.

    ``scaler`` is a ``_FakeScaler`` only when ``with_scaler`` is true, else ``None``;
    ``lr_scheduler`` is always ``None``.
    """

    def __init__(self, with_scaler: bool = False) -> None:
        self.state = _FakeState()
        self.optimizer = _FakeOptimizer()
        self.scaler = None if not with_scaler else _FakeScaler()
        self.lr_scheduler = None
327
328
def _smollm_spec() -> object:
    """Return the registered ``smollm2-135m`` entry from ``BASE_MODELS``.

    Typed as ``object`` at this boundary; call sites pass it with an explicit
    ``type: ignore[arg-type]``.
    """
    # BASE_MODELS is already imported at module scope; no local re-import needed.
    return BASE_MODELS["smollm2-135m"]
333
334
class TestSnapshotTrainingState:
    """Scaler capture and flag passthrough in `_snapshot_training_state`."""

    def test_captures_scaler_when_present(self) -> None:
        from dlm.train.trainer import _snapshot_training_state

        snapshot = _snapshot_training_state(
            _FakeSft(with_scaler=True),
            spec=_smollm_spec(),  # type: ignore[arg-type]
            versions={"torch": "2.4.0"},
            use_qlora=False,
        )

        assert snapshot["global_step"] == 42
        assert snapshot["scaler_state_dict"] == {"scaler": "state"}
        assert snapshot["use_qlora"] is False

    def test_no_scaler_leaves_none(self) -> None:
        from dlm.train.trainer import _snapshot_training_state

        snapshot = _snapshot_training_state(
            _FakeSft(with_scaler=False),
            spec=_smollm_spec(),  # type: ignore[arg-type]
            versions={"torch": "2.4.0"},
            use_qlora=True,
        )

        assert snapshot["use_qlora"] is True
        assert snapshot["scaler_state_dict"] is None
362
363
class TestAttachDlmTrainerCallback:
    """`_attach_dlm_trainer_callback` tolerates trainers it cannot attach to."""

    def test_returns_when_trainer_has_no_add_callback(self) -> None:
        """A trainer lacking ``add_callback`` is accepted without raising."""
        _attach_dlm_trainer_callback(
            trainer=SimpleNamespace(),
            recorder=MagicMock(),
            run_id=1,
            step_logger=MagicMock(),
        )

    def test_warns_and_swallows_callback_attachment_errors(
        self,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """An ``add_callback`` failure is logged as a warning instead of propagating."""
        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
        add_callback = MagicMock(side_effect=RuntimeError("boom"))
        trainer = SimpleNamespace(add_callback=add_callback)

        _attach_dlm_trainer_callback(
            trainer=trainer,
            recorder=MagicMock(),
            run_id=1,
            step_logger=MagicMock(),
        )

        # Pin that the swallowed error really came from add_callback, not from an
        # earlier failure that happens to log the same message.
        add_callback.assert_called_once()
        assert "failed to attach DlmTrainerCallback" in caplog.text
388
389
class TestMaybeRecordTokenization:
    """`_maybe_record_tokenization` records nothing when the trainer carries no stats."""

    def test_missing_trainer_stats_is_a_no_op(self) -> None:
        mock_recorder = MagicMock()
        stats_free_trainer = SimpleNamespace()

        _maybe_record_tokenization(
            recorder=mock_recorder, run_id=1, trainer=stats_free_trainer
        )

        mock_recorder.record_tokenization.assert_not_called()
401
402
class TestAppendChangeSetToReplay:
    """`_append_change_set_to_replay` with a change set made only of image/audio sections."""

    def test_all_media_change_set_does_not_append(self) -> None:
        media_sections = [
            Section(type=SectionType.IMAGE, content="", media_path="hero.png"),
            Section(
                type=SectionType.AUDIO,
                content="",
                media_path="clip.wav",
                media_transcript="spoken transcript",
            ),
        ]
        replay = MagicMock()

        _append_change_set_to_replay(
            replay,
            cast(ChangeSet, SimpleNamespace(new=media_sections)),
            run_id=7,
        )

        replay.append_many.assert_not_called()
425
426
class TestBuildCandidateLock:
    """Input guards of `_build_candidate_lock`."""

    def test_requires_source_path(self, tmp_path: Path) -> None:
        pathless = _parsed(tmp_path, source_path=None)
        smollm = BASE_MODELS["smollm2-135m"]

        with pytest.raises(ValueError, match="source_path is required"):
            _build_candidate_lock(
                parsed=pathless,
                spec=smollm,
                seed=42,
                run_id=1,
                versions={"torch": "2.4.0"},
                determinism_class="strict",
                capabilities=None,
            )
441
442
class TestValidateOrAbortLock:
    """Prior-lock errors and warning surfacing in `_validate_or_abort_lock`."""

    def test_default_mode_reraises_unreadable_prior_lock(self, tmp_path: Path) -> None:
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        # Corrupt prior lock: the file is not valid JSON.
        corrupt_lock = store.root / "dlm.lock"  # type: ignore[attr-defined]
        corrupt_lock.write_text("{not json", encoding="utf-8")

        with pytest.raises(LockSchemaError):
            _validate_or_abort_lock(
                store=store,  # type: ignore[arg-type]
                parsed=parsed,
                spec=BASE_MODELS["smollm2-135m"],
                seed=42,
                run_id=1,
                versions={"torch": "2.4.0"},
                determinism_class="strict",
                capabilities=None,
                lock_mode="default",
            )

    def test_logs_warning_mismatches_when_validator_allows_proceed(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        import dlm.train.trainer as trainer_mod

        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        decision = LockDecision(
            action="proceed_with_warnings",
            mismatches=[(Severity.WARN, "torch minor-version drift")],
            should_write_lock=True,
        )

        def _fake_load_lock(_root: object) -> object:
            # Any non-None prior lock; the stubbed validator ignores its content.
            return object()

        def _fake_validate_lock(
            _prior: object, _candidate: object, mode: str = "default"
        ) -> LockDecision:
            return decision

        monkeypatch.setattr(trainer_mod, "load_lock", _fake_load_lock)
        monkeypatch.setattr(trainer_mod, "validate_lock", _fake_validate_lock)
        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")

        got = _validate_or_abort_lock(
            store=store,  # type: ignore[arg-type]
            parsed=parsed,
            spec=BASE_MODELS["smollm2-135m"],
            seed=42,
            run_id=1,
            versions={"torch": "2.4.0"},
            determinism_class="strict",
            capabilities=None,
            lock_mode="default",
        )

        assert got == decision
        assert "dlm.lock drift: torch minor-version drift" in caplog.text
499
500
class TestComputeWeightDistribution:
    """Row counting in `_compute_weight_distribution`."""

    def test_counts_rows_when_directive_weights_are_active(self, tmp_path: Path) -> None:
        note = Section(type=SectionType.PROSE, content="note", tags={"kind": "note"})
        weighted_config = DiscoveredConfig(
            anchor=tmp_path,
            config=DlmTrainingConfig(weights={"kind": {"note": 2.0}}),
            ignore_rules=(),
        )

        dist = _compute_weight_distribution(
            parsed=_parsed(tmp_path, sections=(note,)),
            directive_discovered=(weighted_config,),
        )

        # One section tagged kind=note → one row counted under that tag value.
        assert dist == {"kind": {"note": 1}}
518
519
class TestExpandDirectives:
    """Base-path choice, empty-expansion passthrough and logging in `_expand_directives`."""

    def test_returns_original_parsed_when_expansion_finds_no_sections(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        parsed = _parsed(tmp_path, sources=(SourceDirective(path="corpus"),))
        discovered = (
            DiscoveredConfig(anchor=tmp_path, config=DlmTrainingConfig(), ignore_rules=()),
        )

        def _fake_expand_sources(parsed_arg: ParsedDlm, *, base_path: Path) -> ExpandResult:
            # With a source_path, expansion is anchored at the .dlm's directory.
            assert parsed_arg is parsed
            assert parsed.source_path is not None
            assert base_path == parsed.source_path.parent
            empty_provenance = SourceProvenance(path="corpus", file_count=0, total_bytes=0)
            return ExpandResult(
                sections=(),
                provenance=(empty_provenance,),
                discovered=discovered,
            )

        monkeypatch.setattr("dlm.directives.expand_sources", _fake_expand_sources)

        result_parsed, provenance, result_discovered = _expand_directives(parsed)

        # Nothing expanded → the very same ParsedDlm object is handed back.
        assert result_parsed is parsed
        assert provenance[0].file_count == 0
        assert result_discovered == discovered

    def test_falls_back_to_cwd_and_logs_when_sections_expand(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        parsed = _parsed(
            tmp_path,
            source_path=None,
            sources=(SourceDirective(path="corpus"),),
        )
        seen_base_paths: list[Path] = []

        def _fake_expand_sources(parsed_arg: ParsedDlm, *, base_path: Path) -> ExpandResult:
            seen_base_paths.append(base_path)
            assert parsed_arg is parsed
            anchored = DiscoveredConfig(
                anchor=base_path, config=DlmTrainingConfig(), ignore_rules=()
            )
            return ExpandResult(
                sections=(Section(type=SectionType.PROSE, content="expanded prose"),),
                provenance=(SourceProvenance(path="corpus", file_count=1, total_bytes=14),),
                discovered=(anchored,),
            )

        monkeypatch.setattr("dlm.directives.expand_sources", _fake_expand_sources)
        caplog.set_level(logging.INFO, logger="dlm.train.trainer")

        result_parsed, provenance, result_discovered = _expand_directives(parsed)

        # No source_path → expansion falls back to the current working directory.
        assert seen_base_paths[-1] == Path.cwd()
        assert len(result_parsed.sections) == len(parsed.sections) + 1
        assert provenance[0].path == "corpus"
        assert len(result_discovered) == 1
        assert "directives: expanded 1 file(s) across 1 source(s)" in caplog.text
601 assert "directives: expanded 1 file(s) across 1 source(s)" in caplog.text