tenseleyflow/documentlanguagemodel / b3e7a82

test(eval,inference): 77-test suite — perplexity/early-stop/probes/retention/summary/plan

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: b3e7a827840d383970306a69aeb5836f31281ba3
Parents: 2c53cae
Tree: ace1760

15 changed files

Status  File                                     +    -
M       src/dlm/eval/probes.py                  15    1
M       src/dlm/inference/loader.py              1    2
M       src/dlm/inference/plan.py                1    2
A       tests/unit/eval/__init__.py              0    0
A       tests/unit/eval/test_early_stop.py      60    0
A       tests/unit/eval/test_perplexity.py      32    0
A       tests/unit/eval/test_probes.py          81    0
A       tests/unit/eval/test_retention.py       84    0
A       tests/unit/eval/test_summary.py        104    0
A       tests/unit/eval/test_val_loss.py        61    0
A       tests/unit/inference/__init__.py         0    0
A       tests/unit/inference/test_generate.py   63    0
A       tests/unit/inference/test_loader.py     50    0
A       tests/unit/inference/test_plan.py      149    0
M       tests/unit/train/test_trainer.py         4    1

src/dlm/eval/probes.py  (modified)
@@ -134,6 +134,12 @@ def _auto_sample_probes(
     Hashes `(seed, question)` and keeps the top-k by hash — a stable
     weighted sample without needing `random.Random`. Excludes any
     prompt already in `exclude` (typically explicit probes).
+
+    Parses the *normalized* section body so sections containing
+    `### Q !probe` headers don't trip the strict instruction parser
+    — we strip the marker, then filter out `!probe:`-prefixed bodies
+    (those are the explicit probes, which the caller has already
+    captured).
     """
     if k <= 0:
         return []
@@ -143,10 +149,18 @@ def _auto_sample_probes(
         if section.type is not SectionType.INSTRUCTION:
             continue
         try:
-            pairs = parse_instruction_body(section.content, section_id=section.section_id)
+            pairs = parse_instruction_body(
+                _normalize_probe_markers(section.content),
+                section_id=section.section_id,
+            )
         except Exception:
             continue
         for pair in pairs:
+            # Skip explicit probes (their question body was prefixed
+            # with `!probe:` by the normalizer) — the caller handles
+            # them separately.
+            if pair.question.startswith(f"{_PROBE_MARKER}:"):
+                continue
             if pair.question in exclude:
                 continue
             candidates.append(
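
Aside: the docstring above leans on the existing hash-based top-k draw. A minimal sketch of that technique, with a hypothetical helper name and a SHA-256 keying that is an assumption rather than the repository's actual code:

```python
# Illustrative sketch only: _auto_sample_probes keys on (seed, question)
# in some stable way; sha256 and the helper name are assumptions.
import hashlib


def stable_top_k(questions: list[str], k: int, seed: int) -> list[str]:
    """Keep the k questions with the smallest hash of (seed, question).

    Deterministic for a fixed seed and input set, with no RNG state and
    no dependence on the order the questions were collected in.
    """

    def key(q: str) -> bytes:
        return hashlib.sha256(f"{seed}:{q}".encode()).digest()

    return sorted(questions, key=key)[:k]
```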

src/dlm/inference/loader.py  (modified)
@@ -119,8 +119,7 @@ def load_for_inference( # pragma: no cover
     adapter_path = store.resolve_current_adapter()
     if adapter_path is None or not adapter_path.exists():
         raise AdapterNotFoundError(
-            f"no adapter under {store.adapter_current_pointer}; "
-            "has `dlm train` run successfully?"
+            f"no adapter under {store.adapter_current_pointer}; has `dlm train` run successfully?"
         )
 
     from transformers import AutoModelForCausalLM, AutoTokenizer

src/dlm/inference/plan.py  (modified)
@@ -97,8 +97,7 @@ def resolve_inference(adapter_dir: Path, caps: Any) -> InferencePlan:
                 dequantize_on_load=True,
                 attn_implementation=_pick_attn(caps),
                 reason=(
-                    "QLoRA adapter but bitsandbytes not installed; "
-                    "dequantizing to fp16 on load."
+                    "QLoRA adapter but bitsandbytes not installed; dequantizing to fp16 on load."
                 ),
             )
         # Plain LoRA on CUDA.

tests/unit/eval/__init__.py  (added, empty)

tests/unit/eval/test_early_stop.py  (added)
@@ -0,0 +1,60 @@
+"""EarlyStopConfig validation + was_early_stopped heuristic."""
+
+from __future__ import annotations
+
+import pytest
+
+from dlm.eval.early_stop import EarlyStopConfig, build_callback, was_early_stopped
+
+
+class TestEarlyStopConfig:
+    def test_defaults(self) -> None:
+        cfg = EarlyStopConfig()
+        assert cfg.patience == 3
+        assert cfg.threshold == 0.0
+        assert cfg.metric == "eval_loss"
+        assert cfg.greater_is_better is False
+
+    def test_patience_below_one_rejected(self) -> None:
+        with pytest.raises(ValueError, match="patience"):
+            EarlyStopConfig(patience=0)
+
+    def test_negative_threshold_rejected(self) -> None:
+        with pytest.raises(ValueError, match="threshold"):
+            EarlyStopConfig(threshold=-0.1)
+
+    def test_empty_metric_rejected(self) -> None:
+        with pytest.raises(ValueError, match="metric"):
+            EarlyStopConfig(metric="")
+
+
+class TestBuildCallback:
+    def test_returns_hf_callback(self) -> None:
+        cfg = EarlyStopConfig(patience=5, threshold=0.01)
+        callback = build_callback(cfg)
+        # HF's EarlyStoppingCallback stores these as attributes.
+        assert callback.early_stopping_patience == 5
+        assert callback.early_stopping_threshold == pytest.approx(0.01)
+
+
+class TestWasEarlyStopped:
+    def test_max_steps_hit_exactly_means_not_stopped(self) -> None:
+        assert not was_early_stopped(
+            max_steps_ran=100, configured_max_steps=100, num_epochs_done=0.4
+        )
+
+    def test_max_steps_not_hit_means_stopped(self) -> None:
+        assert was_early_stopped(max_steps_ran=47, configured_max_steps=100, num_epochs_done=0.2)
+
+    def test_integer_epochs_mean_not_stopped(self) -> None:
+        """Natural completion finishes exactly `num_train_epochs`."""
+        assert not was_early_stopped(
+            max_steps_ran=500, configured_max_steps=None, num_epochs_done=3.0
+        )
+
+    def test_fractional_epoch_means_stopped(self) -> None:
+        assert was_early_stopped(max_steps_ran=200, configured_max_steps=None, num_epochs_done=1.47)
+
+    def test_max_steps_zero_falls_back_to_epoch_check(self) -> None:
+        """`max_steps=0` (or negative) means the cap isn't active."""
+        assert not was_early_stopped(max_steps_ran=300, configured_max_steps=0, num_epochs_done=2.0)
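
Aside: a minimal sketch of the heuristic these tests pin down, inferred from the assertions rather than copied from `dlm.eval.early_stop`:

```python
# Sketch inferred from the tests above; the real implementation may
# differ in detail (e.g. tolerating float noise in num_epochs_done).
def was_early_stopped(
    *, max_steps_ran: int, configured_max_steps: int | None, num_epochs_done: float
) -> bool:
    # An active step cap that was hit exactly is natural completion.
    if configured_max_steps is not None and configured_max_steps > 0:
        return max_steps_ran < configured_max_steps
    # No active cap: natural completion ends on a whole number of epochs,
    # so a fractional epoch count means the callback fired mid-epoch.
    return not float(num_epochs_done).is_integer()
```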

tests/unit/eval/test_perplexity.py  (added)
@@ -0,0 +1,32 @@
+"""Perplexity utility — finite, inf, nan, negative, overflow guards."""
+
+from __future__ import annotations
+
+import math
+
+import pytest
+
+from dlm.eval.perplexity import perplexity
+
+
+class TestPerplexity:
+    def test_zero_loss_gives_one(self) -> None:
+        assert perplexity(0.0) == pytest.approx(1.0)
+
+    def test_positive_loss_gives_exp(self) -> None:
+        assert perplexity(1.0) == pytest.approx(math.e)
+        assert perplexity(math.log(10.0)) == pytest.approx(10.0)
+
+    def test_nan_returns_inf(self) -> None:
+        assert perplexity(math.nan) == math.inf
+
+    def test_inf_returns_inf(self) -> None:
+        assert perplexity(math.inf) == math.inf
+
+    def test_negative_returns_inf(self) -> None:
+        """Negative cross-entropy loss is nonsense — report inf, not a tiny PPL."""
+        assert perplexity(-1.0) == math.inf
+
+    def test_overflow_returns_inf(self) -> None:
+        """`exp(1000.0)` overflows; we substitute inf."""
+        assert perplexity(1000.0) == math.inf
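
Aside: the guard set these six tests pin down fits in a few lines. A sketch consistent with all of them; the module's actual body may differ in detail:

```python
import math


def perplexity(loss: float) -> float:
    """exp(loss), with nan, inf, negative, and overflow all collapsed to inf."""
    if math.isnan(loss) or math.isinf(loss) or loss < 0.0:
        return math.inf
    try:
        return math.exp(loss)
    except OverflowError:
        # exp() of a very large finite loss (e.g. 1000.0) overflows a double.
        return math.inf
```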

tests/unit/eval/test_probes.py  (added)
@@ -0,0 +1,81 @@
+"""Probe prompt extraction — explicit `!probe` + auto-sample fallback."""
+
+from __future__ import annotations
+
+import dataclasses
+
+import pytest
+
+from dlm.doc.sections import Section, SectionType
+from dlm.eval.probes import Probe, extract_probes
+
+
+class TestExplicitProbes:
+    def test_single_probe(self) -> None:
+        body = "### Q !probe\nWhat is Paris?\n### A\nCapital of France."
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        probes = extract_probes([s], k=3)
+        assert len(probes) == 1
+        assert probes[0].prompt == "What is Paris?"
+        assert probes[0].reference == "Capital of France."
+        assert probes[0].section_id == s.section_id
+
+    def test_multiple_explicit_probes_limited_by_k(self) -> None:
+        body = (
+            "### Q !probe\nQ1?\n### A\nA1\n\n"
+            "### Q !probe\nQ2?\n### A\nA2\n\n"
+            "### Q !probe\nQ3?\n### A\nA3"
+        )
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        probes = extract_probes([s], k=2)
+        assert len(probes) == 2
+        assert [p.prompt for p in probes] == ["Q1?", "Q2?"]
+
+    def test_non_probe_questions_ignored_when_explicit_present(self) -> None:
+        body = "### Q !probe\nexplicit\n### A\nA1\n\n### Q\nnot-probe\n### A\nA2"
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        probes = extract_probes([s], k=3)
+        assert len(probes) == 2
+        # Explicit one comes first.
+        assert probes[0].prompt == "explicit"
+        # Auto-sampled fills the remainder.
+        assert any(p.prompt == "not-probe" for p in probes)
+
+
+class TestAutoSample:
+    def test_auto_sample_when_no_explicit(self) -> None:
+        body = "### Q\nQ1?\n### A\nA1\n\n### Q\nQ2?\n### A\nA2\n\n### Q\nQ3?\n### A\nA3"
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        probes = extract_probes([s], k=2, seed=42)
+        assert len(probes) == 2
+
+    def test_auto_sample_deterministic(self) -> None:
+        body = "\n\n".join(f"### Q\nQ{i}?\n### A\nA{i}" for i in range(10))
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        a = extract_probes([s], k=3, seed=7)
+        b = extract_probes([s], k=3, seed=7)
+        assert [p.prompt for p in a] == [p.prompt for p in b]
+
+    def test_different_seeds_yield_different_picks(self) -> None:
+        body = "\n\n".join(f"### Q\nQ{i}?\n### A\nA{i}" for i in range(10))
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        a = extract_probes([s], k=3, seed=1)
+        b = extract_probes([s], k=3, seed=99)
+        assert {p.prompt for p in a} != {p.prompt for p in b}
+
+    def test_no_instruction_sections_returns_empty(self) -> None:
+        """Prose-only docs have nothing to probe — return [] rather than error."""
+        s = Section(type=SectionType.PROSE, content="just prose, no Q/A")
+        assert extract_probes([s], k=3) == []
+
+    def test_k_zero_returns_empty(self) -> None:
+        body = "### Q !probe\nx\n### A\ny"
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+        assert extract_probes([s], k=0) == []
+
+
+class TestProbeDataclass:
+    def test_probe_is_frozen(self) -> None:
+        p = Probe(prompt="hi", reference="hello")
+        with pytest.raises(dataclasses.FrozenInstanceError):
+            p.prompt = "other"  # type: ignore[misc]

tests/unit/eval/test_retention.py  (added)
@@ -0,0 +1,84 @@
+"""Retention slice determinism + stability."""
+
+from __future__ import annotations
+
+from datetime import datetime
+
+import pytest
+
+from dlm.eval.errors import RetentionSliceError
+from dlm.eval.retention import build_retention_slice, retention_delta
+from dlm.replay.models import IndexEntry
+
+
+def _entry(sid: str) -> IndexEntry:
+    return IndexEntry(
+        section_id=sid,
+        byte_offset=0,
+        length=100,
+        added_at=datetime(2026, 1, 1),
+    )
+
+
+class TestBuildRetentionSlice:
+    def test_empty_corpus_raises(self) -> None:
+        with pytest.raises(RetentionSliceError, match="empty"):
+            build_retention_slice([], frac=0.1, seed=0)
+
+    def test_frac_out_of_range_raises(self) -> None:
+        entries = [_entry(f"{i:016x}") for i in range(10)]
+        with pytest.raises(RetentionSliceError, match="frac"):
+            build_retention_slice(entries, frac=0.0, seed=0)
+        with pytest.raises(RetentionSliceError, match="frac"):
+            build_retention_slice(entries, frac=1.5, seed=0)
+
+    def test_frac_picks_expected_count(self) -> None:
+        entries = [_entry(f"{i:016x}") for i in range(100)]
+        slice_ = build_retention_slice(entries, frac=0.1, seed=42)
+        assert len(slice_.entries) == 10
+
+    def test_small_corpus_picks_at_least_one(self) -> None:
+        """5% of 3 is 0.15 → rounds up to 1."""
+        entries = [_entry(f"{i:016x}") for i in range(3)]
+        slice_ = build_retention_slice(entries, frac=0.05, seed=0)
+        assert len(slice_.entries) == 1
+
+    def test_seed_stable(self) -> None:
+        entries = [_entry(f"{i:016x}") for i in range(50)]
+        a = build_retention_slice(entries, frac=0.1, seed=7)
+        b = build_retention_slice(entries, frac=0.1, seed=7)
+        assert [e.section_id for e in a.entries] == [e.section_id for e in b.entries]
+
+    def test_different_seeds_different_slice(self) -> None:
+        entries = [_entry(f"{i:016x}") for i in range(50)]
+        a = build_retention_slice(entries, frac=0.1, seed=1)
+        b = build_retention_slice(entries, frac=0.1, seed=999)
+        assert {e.section_id for e in a.entries} != {e.section_id for e in b.entries}
+
+    def test_identical_inputs_identical_slice(self) -> None:
+        """Same corpus + same seed → identical slice across calls.
+
+        Unlike Sprint 08's splitter, the retention slice is NOT
+        growth-stable: adding new entries to the corpus can displace
+        existing members from the top-k. The spec only requires
+        seed-determinism for a fixed input, which this asserts.
+        Cross-run comparability comes from reporting the loss delta on
+        whatever slice the current run sees, not from freezing the
+        slice membership across corpus growth.
+        """
+        entries = [_entry(f"{i:016x}") for i in range(100)]
+        first = build_retention_slice(entries, frac=0.05, seed=42)
+        second = build_retention_slice(entries, frac=0.05, seed=42)
+        assert first.section_ids == second.section_ids
+
+
+class TestRetentionDelta:
+    def test_both_present(self) -> None:
+        assert retention_delta(
+            current_retention_loss=1.5, previous_retention_loss=1.2
+        ) == pytest.approx(0.3)
+
+    def test_none_when_either_missing(self) -> None:
+        assert retention_delta(current_retention_loss=None, previous_retention_loss=1.0) is None
+        assert retention_delta(current_retention_loss=1.0, previous_retention_loss=None) is None
+        assert retention_delta(current_retention_loss=None, previous_retention_loss=None) is None
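
Aside: sketches of the sizing and delta rules these assertions force. `retention_slice_size` is a hypothetical helper, not the module's API, and the ceil-based rounding is inferred only from the "rounds up to 1" docstring:

```python
import math


def retention_slice_size(n_entries: int, frac: float) -> int:
    """Ceil of frac * n, floored at 1, so tiny corpora still keep one entry."""
    return max(1, math.ceil(frac * n_entries))


def retention_delta(
    *, current_retention_loss: float | None, previous_retention_loss: float | None
) -> float | None:
    """current minus previous, or None when either run lacks a measurement."""
    if current_retention_loss is None or previous_retention_loss is None:
        return None
    return current_retention_loss - previous_retention_loss
```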

tests/unit/eval/test_summary.py  (added)
@@ -0,0 +1,104 @@
+"""TrainingSummary schema + round-trip."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+from pydantic import ValidationError
+
+from dlm.eval.summary import (
+    ProbeOutput,
+    TrainingSummary,
+    load_summary,
+    save_summary,
+    summary_path_for,
+)
+
+
+def _summary(**overrides: object) -> TrainingSummary:
+    base: dict[str, object] = {
+        "run_id": 1,
+        "adapter_version": 1,
+        "seed": 42,
+        "steps": 100,
+        "final_train_loss": 1.23,
+        "final_val_loss": 1.45,
+        "final_val_perplexity": 4.26,
+        "early_stopped": False,
+        "duration_seconds": 12.5,
+        "determinism_class": "strict",
+    }
+    base.update(overrides)
+    return TrainingSummary.model_validate(base)
+
+
+class TestSchema:
+    def test_minimal_accepted(self) -> None:
+        s = TrainingSummary(run_id=1, adapter_version=1, seed=0)
+        assert s.run_id == 1
+        assert s.final_train_loss is None
+        assert s.probes == []
+
+    def test_run_id_must_be_positive(self) -> None:
+        with pytest.raises(ValidationError):
+            TrainingSummary(run_id=0, adapter_version=1, seed=0)
+
+    def test_adapter_version_must_be_positive(self) -> None:
+        with pytest.raises(ValidationError):
+            TrainingSummary(run_id=1, adapter_version=0, seed=0)
+
+    def test_extra_fields_forbidden(self) -> None:
+        with pytest.raises(ValidationError):
+            TrainingSummary.model_validate(
+                {"run_id": 1, "adapter_version": 1, "seed": 0, "bonus": "nope"}
+            )
+
+    def test_frozen(self) -> None:
+        s = _summary()
+        with pytest.raises(ValidationError):
+            s.run_id = 2  # type: ignore[misc]
+
+    def test_probes_roundtrip(self) -> None:
+        s = _summary(
+            probes=[
+                {"prompt": "Q?", "response": "A.", "reference": "ref", "section_id": "sid0"},
+            ]
+        )
+        assert isinstance(s.probes[0], ProbeOutput)
+        assert s.probes[0].prompt == "Q?"
+
+
+class TestSaveLoad:
+    def test_round_trip_via_json(self, tmp_path: Path) -> None:
+        s = _summary(
+            probes=[{"prompt": "Q?", "response": "A.", "reference": None, "section_id": ""}]
+        )
+        p = tmp_path / "summary.json"
+        save_summary(p, s)
+        back = load_summary(p)
+        assert back == s
+
+    def test_written_file_is_sorted_pretty_json(self, tmp_path: Path) -> None:
+        s = _summary()
+        p = tmp_path / "summary.json"
+        save_summary(p, s)
+        text = p.read_text()
+        # Pretty (indented) + trailing newline.
+        assert text.endswith("\n")
+        assert "  " in text
+        # Sorted keys mean `adapter_version` appears before `run_id`.
+        data = json.loads(text)
+        keys = list(data.keys())
+        assert keys == sorted(keys)
+
+
+class TestSummaryPathFor:
+    def test_matches_log_stem(self, tmp_path: Path) -> None:
+        # summary_path_for normalizes timestamps.
+        p = summary_path_for(tmp_path, 7, "2026-04-18T10:15:23")
+        assert p.parent == tmp_path
+        assert p.name.startswith("train-000007-")
+        assert p.suffix == ".json"
+        assert ".summary" in p.name

tests/unit/eval/test_val_loss.py  (added)
@@ -0,0 +1,61 @@
+"""compute_metrics hook + summarize_eval_state."""
+
+from __future__ import annotations
+
+import math
+from types import SimpleNamespace
+
+import pytest
+
+from dlm.eval.val_loss import eval_metrics_from_eval_pred, summarize_eval_state
+
+
+class TestEvalMetricsFromEvalPred:
+    def test_returns_perplexity(self) -> None:
+        pred = SimpleNamespace(metrics={"eval_loss": math.log(10.0)})
+        result = eval_metrics_from_eval_pred(pred)
+        assert result["perplexity"] == pytest.approx(10.0)
+
+    def test_missing_metrics_returns_empty(self) -> None:
+        pred = SimpleNamespace(metrics=None)
+        assert eval_metrics_from_eval_pred(pred) == {}
+
+    def test_non_numeric_loss_returns_empty(self) -> None:
+        pred = SimpleNamespace(metrics={"eval_loss": "not a number"})
+        assert eval_metrics_from_eval_pred(pred) == {}
+
+    def test_no_metrics_attr(self) -> None:
+        """An EvalPrediction without `metrics` attribute returns empty."""
+        pred = SimpleNamespace()
+        assert eval_metrics_from_eval_pred(pred) == {}
+
+
+class TestSummarizeEvalState:
+    def test_picks_last_eval_loss(self) -> None:
+        history = [
+            {"loss": 2.5, "step": 10},
+            {"eval_loss": 2.1, "step": 10},
+            {"loss": 2.3, "step": 20},
+            {"eval_loss": 1.9, "step": 20},
+        ]
+        result = summarize_eval_state(history)
+        assert result["final_val_loss"] == pytest.approx(1.9)
+        assert result["final_val_perplexity"] == pytest.approx(math.exp(1.9))
+
+    def test_no_eval_loss_in_history(self) -> None:
+        result = summarize_eval_state([{"loss": 1.0}])
+        assert result["final_val_loss"] is None
+        assert result["final_val_perplexity"] is None
+
+    def test_empty_history(self) -> None:
+        result = summarize_eval_state([])
+        assert result["final_val_loss"] is None
+        assert result["final_val_perplexity"] is None
+
+    def test_non_numeric_eval_loss_skipped(self) -> None:
+        history = [
+            {"eval_loss": "oops", "step": 10},
+            {"eval_loss": 1.5, "step": 20},
+        ]
+        result = summarize_eval_state(history)
+        assert result["final_val_loss"] == pytest.approx(1.5)

tests/unit/inference/__init__.py  (added, empty)

tests/unit/inference/test_generate.py  (added)
@@ -0,0 +1,63 @@
+"""`build_generate_kwargs` — deterministic vs. sampled argument resolution."""
+
+from __future__ import annotations
+
+import pytest
+
+from dlm.inference.generate import DEFAULT_MAX_NEW_TOKENS, build_generate_kwargs
+
+
+class TestDeterministicPath:
+    def test_temperature_zero_is_deterministic(self) -> None:
+        kwargs = build_generate_kwargs(max_new_tokens=32, temperature=0.0)
+        assert kwargs["do_sample"] is False
+        assert kwargs["num_beams"] == 1
+        # Temperature must NOT leak through when do_sample=False.
+        assert "temperature" not in kwargs
+        assert kwargs["max_new_tokens"] == 32
+
+    def test_default_max_new_tokens(self) -> None:
+        kwargs = build_generate_kwargs()
+        assert kwargs["max_new_tokens"] == DEFAULT_MAX_NEW_TOKENS
+
+
+class TestSampledPath:
+    def test_non_zero_temperature_flips_sampling(self) -> None:
+        kwargs = build_generate_kwargs(max_new_tokens=100, temperature=0.7)
+        assert kwargs["do_sample"] is True
+        assert kwargs["temperature"] == pytest.approx(0.7)
+        assert "num_beams" not in kwargs
+
+    def test_top_p_threaded_when_sampling(self) -> None:
+        kwargs = build_generate_kwargs(temperature=0.5, top_p=0.9)
+        assert kwargs["top_p"] == pytest.approx(0.9)
+
+    def test_top_k_threaded_when_sampling(self) -> None:
+        kwargs = build_generate_kwargs(temperature=0.5, top_k=40)
+        assert kwargs["top_k"] == 40
+
+    def test_top_p_ignored_on_deterministic_path(self) -> None:
+        kwargs = build_generate_kwargs(temperature=0.0, top_p=0.9)
+        assert "top_p" not in kwargs
+
+
+class TestCommon:
+    def test_repetition_penalty_threaded_both_paths(self) -> None:
+        kwargs_det = build_generate_kwargs(temperature=0.0, repetition_penalty=1.1)
+        assert kwargs_det["repetition_penalty"] == pytest.approx(1.1)
+        kwargs_sample = build_generate_kwargs(temperature=0.5, repetition_penalty=1.1)
+        assert kwargs_sample["repetition_penalty"] == pytest.approx(1.1)
+
+
+class TestValidation:
+    def test_zero_max_new_tokens_rejected(self) -> None:
+        with pytest.raises(ValueError, match="max_new_tokens"):
+            build_generate_kwargs(max_new_tokens=0)
+
+    def test_negative_max_new_tokens_rejected(self) -> None:
+        with pytest.raises(ValueError, match="max_new_tokens"):
+            build_generate_kwargs(max_new_tokens=-5)
+
+    def test_negative_temperature_rejected(self) -> None:
+        with pytest.raises(ValueError, match="temperature"):
+            build_generate_kwargs(temperature=-0.1)
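
Aside: a sketch of the two-path resolution these tests describe. The `DEFAULT_MAX_NEW_TOKENS` value here is an assumption; the tests only pin that the constant exists and is the default:

```python
DEFAULT_MAX_NEW_TOKENS = 256  # assumed value; only its existence is tested


def build_generate_kwargs(
    *,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.0,
    top_p: float | None = None,
    top_k: int | None = None,
    repetition_penalty: float | None = None,
) -> dict:
    if max_new_tokens <= 0:
        raise ValueError("max_new_tokens must be positive")
    if temperature < 0.0:
        raise ValueError("temperature must be non-negative")
    kwargs: dict = {"max_new_tokens": max_new_tokens}
    if temperature == 0.0:
        # Deterministic path: greedy decode; sampling knobs must not leak.
        kwargs["do_sample"] = False
        kwargs["num_beams"] = 1
    else:
        kwargs["do_sample"] = True
        kwargs["temperature"] = temperature
        if top_p is not None:
            kwargs["top_p"] = top_p
        if top_k is not None:
            kwargs["top_k"] = top_k
    # Repetition penalty applies on both paths.
    if repetition_penalty is not None:
        kwargs["repetition_penalty"] = repetition_penalty
    return kwargs
```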

tests/unit/inference/test_loader.py  (added)
@@ -0,0 +1,50 @@
+"""`build_load_kwargs` — config-assembly without touching HF."""
+
+from __future__ import annotations
+
+from dlm.base_models import BASE_MODELS
+from dlm.hardware.backend import Backend
+from dlm.inference.loader import build_load_kwargs
+from dlm.inference.plan import InferencePlan
+
+
+def _plan(**overrides: object) -> InferencePlan:
+    base: dict[str, object] = {
+        "backend": Backend.CUDA,
+        "precision": "bf16",
+        "dequantize_on_load": False,
+        "attn_implementation": "sdpa",
+        "reason": "test",
+    }
+    base.update(overrides)
+    return InferencePlan(**base)  # type: ignore[arg-type]
+
+
+class TestBuildLoadKwargs:
+    def test_basic_fp16_kwargs(self) -> None:
+        spec = BASE_MODELS["smollm2-135m"]
+        plan = _plan(backend=Backend.MPS, precision="fp16")
+        kwargs = build_load_kwargs(spec, plan, has_bitsandbytes=False)
+        assert kwargs["revision"] == spec.revision
+        assert kwargs["attn_implementation"] == "sdpa"
+        # No quantization config on non-CUDA.
+        assert "quantization_config" not in kwargs
+        assert "torch_dtype" in kwargs
+
+    def test_dequantize_path_omits_bnb_config(self) -> None:
+        """dequantize_on_load=True → no BitsAndBytesConfig even if bnb is installed."""
+        spec = BASE_MODELS["smollm2-135m"]
+        plan = _plan(dequantize_on_load=True, precision="fp16")
+        kwargs = build_load_kwargs(spec, plan, has_bitsandbytes=True)
+        assert "quantization_config" not in kwargs
+
+    def test_plain_lora_uses_torch_dtype(self) -> None:
+        spec = BASE_MODELS["smollm2-135m"]
+        plan = _plan(backend=Backend.CUDA, precision="bf16", dequantize_on_load=False)
+        # Quantization state is decided upstream: the plan already encodes the
+        # final choice via `dequantize_on_load`, and this function only
+        # assembles kwargs from the plan. With has_bitsandbytes=False there is
+        # definitely no quantization config to build.
+        kwargs = build_load_kwargs(spec, plan, has_bitsandbytes=False)
+        assert "quantization_config" not in kwargs
+        assert "torch_dtype" in kwargs

tests/unit/inference/test_plan.py  (added)
@@ -0,0 +1,149 @@
+"""InferencePlan resolver — audit F05 cross-hardware coverage."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+from dlm.hardware.backend import Backend
+from dlm.inference.plan import InferencePlan, resolve_inference
+
+
+def _caps(
+    *,
+    backend: Backend,
+    supports_bf16: bool = False,
+    has_bitsandbytes: bool = False,
+    has_flash_attention: bool = False,
+) -> object:
+    return SimpleNamespace(
+        backend=backend,
+        supports_bf16=supports_bf16,
+        has_bitsandbytes=has_bitsandbytes,
+        has_flash_attention=has_flash_attention,
+    )
+
+
+def _write_pinned(adapter_dir: Path, *, bnb: str | None) -> None:
+    adapter_dir.mkdir(parents=True, exist_ok=True)
+    (adapter_dir / "pinned_versions.json").write_text(
+        json.dumps({"torch": "2.4.0", "bitsandbytes": bnb})
+    )
+
+
+class TestQLoRAOnCUDAWithBnb:
+    def test_loads_4bit_native(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb="0.43.1")
+        plan = resolve_inference(
+            tmp_path, _caps(backend=Backend.CUDA, supports_bf16=True, has_bitsandbytes=True)
+        )
+        assert plan.backend == Backend.CUDA
+        assert plan.precision == "bf16"
+        assert plan.dequantize_on_load is False
+        assert "4-bit" in plan.reason
+
+
+class TestQLoRAOnCUDAWithoutBnb:
+    def test_dequantizes(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb="0.43.1")
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.CUDA, has_bitsandbytes=False))
+        assert plan.dequantize_on_load is True
+        assert plan.precision == "fp16"
+        assert "bitsandbytes not installed" in plan.reason
+
+
+class TestQLoRAOnMPS:
+    """Audit F05 canonical case — CUDA-trained QLoRA resumed on Apple Silicon."""
+
+    def test_dequantizes_to_fp16(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb="0.43.1")
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        assert plan.backend == Backend.MPS
+        assert plan.precision == "fp16"
+        assert plan.dequantize_on_load is True
+        assert plan.attn_implementation == "sdpa"
+        assert "F05" in plan.reason
+
+
+class TestLoRANonCUDA:
+    def test_mps_plain_lora(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        assert plan.precision == "fp16"
+        assert plan.dequantize_on_load is False
+
+    def test_cpu_plain_lora(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.CPU))
+        assert plan.dequantize_on_load is False
+
+
+class TestLoRAOnCUDA:
+    def test_bf16_when_supported(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.CUDA, supports_bf16=True))
+        assert plan.precision == "bf16"
+        assert plan.dequantize_on_load is False
+
+    def test_fp16_when_bf16_unsupported(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.CUDA, supports_bf16=False))
+        assert plan.precision == "fp16"
+
+
+class TestAttnImplPick:
+    def test_flash_attn_when_available(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb="0.43.1")
+        plan = resolve_inference(
+            tmp_path,
+            _caps(
+                backend=Backend.CUDA,
+                supports_bf16=True,
+                has_bitsandbytes=True,
+                has_flash_attention=True,
+            ),
+        )
+        assert plan.attn_implementation == "flash_attention_2"
+
+    def test_sdpa_default(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.CUDA))
+        assert plan.attn_implementation == "sdpa"
+
+
+class TestMissingPinnedFile:
+    def test_no_pinned_versions_treated_as_lora(self, tmp_path: Path) -> None:
+        """Missing `pinned_versions.json` is conservative: assume LoRA."""
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        assert plan.dequantize_on_load is False
+
+    def test_malformed_pinned_file_treated_as_lora(self, tmp_path: Path) -> None:
+        (tmp_path / "pinned_versions.json").write_text("not json {{{")
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        assert plan.dequantize_on_load is False
+
+
+class TestPlanSerialization:
+    def test_to_dict_is_json_friendly(self, tmp_path: Path) -> None:
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        data = plan.to_dict()
+        # Round-trip via json to prove serializability.
+        encoded = json.dumps(data)
+        decoded = json.loads(encoded)
+        assert decoded["backend"] == "mps"
+        assert decoded["precision"] == "fp16"
+
+    def test_plan_is_frozen(self, tmp_path: Path) -> None:
+        import dataclasses
+
+        _write_pinned(tmp_path, bnb=None)
+        plan = resolve_inference(tmp_path, _caps(backend=Backend.MPS))
+        assert isinstance(plan, InferencePlan)
+        try:
+            plan.precision = "bf16"  # type: ignore[misc]
+        except dataclasses.FrozenInstanceError:
+            pass
+        else:
+            raise AssertionError("frozen=True not enforced")
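
Aside: the cross-hardware matrix above condenses to a small decision tree. A sketch, not the module's code, where `is_qlora` stands in for "the adapter's `pinned_versions.json` records a bitsandbytes version" and missing or malformed files count as plain LoRA, per `TestMissingPinnedFile`:

```python
from dlm.hardware.backend import Backend  # the enum the tests use


def resolve_sketch(is_qlora: bool, caps) -> tuple[str, bool]:
    """Return (precision, dequantize_on_load) for the cases the tests pin."""
    if is_qlora:
        if caps.backend is Backend.CUDA and caps.has_bitsandbytes:
            # Native 4-bit load on CUDA with bnb available.
            return ("bf16" if caps.supports_bf16 else "fp16", False)
        # Audit F05: QLoRA without usable bnb (MPS, CPU, or CUDA sans bnb)
        # dequantizes to fp16 on load.
        return ("fp16", True)
    # Plain LoRA: bf16 only on CUDA hardware that supports it.
    if caps.backend is Backend.CUDA and caps.supports_bf16:
        return ("bf16", False)
    return ("fp16", False)
```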

tests/unit/train/test_trainer.py  (modified)
@@ -213,7 +213,10 @@ class TestRunHappyPath:
         spec = BASE_MODELS["smollm2-135m"]
 
         result = run(
-            store, _parsed(), spec, _plan(),
+            store,
+            _parsed(),
+            spec,
+            _plan(),
             trainer_factory=_mock_trainer_factory,
         )
219222