tenseleyflow/documentlanguagemodel / fd7e9a0

Raise all coverage gates to 100%

* Cover store edge branches

* Cover hardware edge branches

* Cover CLI reporter branches

* Cover eval edge branches

* Cover export helper branches

* Cover export target helper branches

* Cover export runner branches

* Cover export helper modules

* Cover multimodal export branches

* Finish export coverage sweep

* Cover mlx backend staging refusal

* Cover replay decode edge cases

* Cover preference phase plumbing

* Finish doc and registry edge coverage

* Cover data helper edge branches

* Finish data resample coverage

* Cover train cache toggles

* Cover tokenization coercion edges

* Finish checkpoint commit coverage

* Cover RPC validation edges

* Finish state sidecar coverage

* Finish train helper coverage sweep

* Finish trainer helper coverage

* Finish metrics coverage sweep

* Finish directives coverage sweep

* Cover synth apply and pending

* Finish synth prompt and filter coverage

* Finish synth run coverage

* Cover synth teacher helpers

* Cover synth teacher runtime helpers

* Finish synth teacher coverage

* Retry transient empty lockfiles

* Skip peer roundtrip when bind denied

* Cover preference pending helpers

* Finish preference judge coverage

* Finish preference mine coverage

* Cover modality wrapper modules

* Cover repl and watch helpers

* Cover scaffold and template edges

* Cover app and harvest helpers

* Finish share helper coverage

* Cover share pull orchestrator

* Cover share push orchestrator

* Cover peer token edge cases

* Cover peer runtime flow

* Cover control apply edge branches

* Cover init helper edge cases

* Cover train validation edge cases

* Cover train watch edge cases

* Cover prompt edge cases

* Cover GPU resolve failure

* Cover export edge paths

* Cover export runtime targets

* Cover export run errors

* Cover audio probe generic processor load failure

* Raise coverage gates from 95% to 100%

* Fix import sort order in test_state_sidecar and test_tokenization

* Apply ruff format to 16 test files

* Cover vllm platform helpers on all OS targets

---------

Co-authored-by: mfwolffe <wolffemf@dukes.jmu.edu>
Authored by espadonne
Committed by GitHub
SHA: fd7e9a0261360741d0d6496010c1727f55375496
Parents: 20cfe6f
Tree: 874199c

121 changed files

Status  File  +  -
M .github/workflows/ci.yml 32 32
M scripts/coverage-gates.sh 1 1
M src/dlm/eval/probes.py 0 1
M src/dlm/store/lock.py 8 3
M src/dlm/synth/run.py 0 2
M src/dlm/train/cpt/schedule.py 0 2
M tests/integration/share/test_peer_roundtrip.py 10 3
M tests/unit/base_models/test_probes.py 16 0
A tests/unit/cli/test_app_core.py 91 0
A tests/unit/cli/test_export_edge_paths.py 258 0
A tests/unit/cli/test_export_run_errors.py 191 0
A tests/unit/cli/test_export_target_runtime_paths.py 352 0
A tests/unit/cli/test_init_edges.py 157 0
A tests/unit/cli/test_prompt_edges.py 291 0
M tests/unit/cli/test_reporter.py 27 1
M tests/unit/cli/test_scaffold.py 16 0
M tests/unit/cli/test_train_prompt_repl_coverage.py 10 0
A tests/unit/cli/test_train_validation_edges.py 217 0
A tests/unit/cli/test_train_watch_edges.py 260 0
M tests/unit/control/test_apply.py 34 0
M tests/unit/data/test_audio_cache.py 30 0
M tests/unit/data/test_audio_resample.py 68 22
M tests/unit/data/test_dataset_builder.py 17 0
M tests/unit/data/test_sections_to_rows.py 14 0
M tests/unit/data/test_vl_cache.py 30 0
M tests/unit/data/test_weighted_rows.py 28 0
M tests/unit/directives/test_cache.py 182 0
M tests/unit/directives/test_cache_key.py 14 1
M tests/unit/directives/test_discovery.py 37 0
M tests/unit/directives/test_expand.py 64 0
M tests/unit/directives/test_ignore_parser.py 7 0
M tests/unit/directives/test_merge.py 23 0
M tests/unit/doc/test_parser_roundtrip.py 29 0
M tests/unit/doc/test_serializer_edges.py 12 0
M tests/unit/eval/test_mode_split.py 37 1
M tests/unit/eval/test_probes.py 33 1
M tests/unit/export/ollama/test_modelfile.py 5 0
M tests/unit/export/targets/test_llama_server_argv.py 128 1
M tests/unit/export/targets/test_mlx_serve_argv.py 95 0
M tests/unit/export/targets/test_vllm_argv.py 47 0
M tests/unit/export/test_arch_probe.py 17 0
M tests/unit/export/test_audio_snapshot.py 81 0
M tests/unit/export/test_draft_registry.py 14 0
M tests/unit/export/test_embedding_sync.py 34 0
M tests/unit/export/test_gate_fallback_resolve.py 50 1
A tests/unit/export/test_gguf_io.py 33 0
M tests/unit/export/test_gguf_tensors.py 63 0
M tests/unit/export/test_imatrix.py 99 0
A tests/unit/export/test_merge.py 21 0
M tests/unit/export/test_precision_safety.py 4 0
M tests/unit/export/test_preflight.py 45 0
M tests/unit/export/test_runner.py 274 0
M tests/unit/export/test_smoke.py 21 0
M tests/unit/export/test_vendoring.py 31 0
M tests/unit/export/test_vl_gguf.py 113 1
M tests/unit/export/test_vl_snapshot.py 77 0
M tests/unit/hardware/test_capabilities.py 57 1
M tests/unit/hardware/test_plan.py 52 1
M tests/unit/harvest/test_sway_reader.py 12 0
M tests/unit/inference/test_mlx_backend.py 20 1
M tests/unit/lock/test_mismatch_policy.py 7 0
M tests/unit/metrics/test_queries.py 217 1
M tests/unit/metrics/test_recorder.py 54 2
A tests/unit/metrics/test_sinks.py 168 0
A tests/unit/modality/test_dispatch_modules.py 136 0
A tests/unit/modality/test_vl_contract.py 72 0
M tests/unit/preference/test_cli_judge.py 66 0
M tests/unit/preference/test_hf_reward_judge.py 225 2
M tests/unit/preference/test_mine_dedup.py 69 0
A tests/unit/preference/test_pending.py 199 0
A tests/unit/preference/test_sway_bridge.py 262 0
A tests/unit/repl/test_app_helpers.py 10 0
M tests/unit/repl/test_commands.py 6 1
A tests/unit/repl/test_streaming.py 64 0
M tests/unit/replay/test_corpus.py 15 0
M tests/unit/replay/test_sampler.py 26 1
M tests/unit/share/test_hf_sink.py 21 0
A tests/unit/share/test_peer_runtime.py 388 0
M tests/unit/share/test_peer_tokens.py 41 0
M tests/unit/share/test_provenance.py 63 0
A tests/unit/share/test_pull.py 393 0
A tests/unit/share/test_push.py 435 0
M tests/unit/share/test_signing.py 140 0
M tests/unit/share/test_url_sink.py 81 0
M tests/unit/store/test_blobs.py 26 0
M tests/unit/store/test_inspect.py 56 1
M tests/unit/store/test_lock.py 47 2
M tests/unit/store/test_paths.py 36 1
A tests/unit/synth/test_apply_pending.py 337 0
M tests/unit/synth/test_filter.py 48 0
M tests/unit/synth/test_prompts.py 19 5
M tests/unit/synth/test_run_dry_run.py 118 2
M tests/unit/synth/test_teachers.py 871 22
M tests/unit/templates/test_init.py 28 0
M tests/unit/templates/test_registry.py 16 0
M tests/unit/test_io_atomic.py 39 0
A tests/unit/test_main.py 15 0
A tests/unit/test_package_init.py 24 0
M tests/unit/train/cpt/test_embed_warmup.py 16 0
M tests/unit/train/distributed/test_gpus.py 12 0
M tests/unit/train/distributed/test_rank_env.py 4 0
M tests/unit/train/distributed/test_rank_io.py 5 0
M tests/unit/train/gate/test_module.py 23 0
M tests/unit/train/gate/test_orchestrator.py 16 0
M tests/unit/train/gate/test_trainer.py 18 0
M tests/unit/train/multi_adapter/test_orchestrator.py 66 0
M tests/unit/train/preference/test_dpo_phase.py 30 0
M tests/unit/train/preference/test_orpo_phase.py 30 0
M tests/unit/train/preference/test_phase_orchestrator.py 28 0
A tests/unit/train/test_cache.py 38 0
M tests/unit/train/test_checkpoint_commit.py 82 0
M tests/unit/train/test_inject.py 4 0
M tests/unit/train/test_integrity.py 3 0
M tests/unit/train/test_logger.py 18 0
M tests/unit/train/test_rpc.py 108 0
M tests/unit/train/test_state_sidecar.py 67 0
M tests/unit/train/test_tokenization.py 18 0
M tests/unit/train/test_trainer_helpers.py 300 9
M tests/unit/watch/test_debounce.py 7 0
M tests/unit/watch/test_watcher_filter.py 31 0
A tests/unit/watch/test_watcher_loop.py 55 0
.github/workflows/ci.yml (modified, 165 lines changed)

@@ -63,133 +63,133 @@ jobs:
       - name: Pytest (unit + integration, non-slow)
        run: uv run pytest

-      - name: Coverage gate — src/dlm/doc ≥ 95% (audit 02 M4)
+      - name: Coverage gate — src/dlm/doc = 100% (audit 02 M4)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/doc \
            --cov=src/dlm/doc \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/store ≥ 95% (Sprint 04)
+      - name: Coverage gate — src/dlm/store = 100% (Sprint 04)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/store \
            --cov=src/dlm/store \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/hardware ≥ 95% (Sprint 05)
+      - name: Coverage gate — src/dlm/hardware = 100% (Sprint 05)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/hardware \
            --cov=src/dlm/hardware \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/base_models ≥ 95% (Sprint 06)
+      - name: Coverage gate — src/dlm/base_models = 100% (Sprint 06)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/base_models \
            --cov=src/dlm/base_models \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/data ≥ 95% (Sprint 07)
+      - name: Coverage gate — src/dlm/data = 100% (Sprint 07)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/data \
            --cov=src/dlm/data \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/replay ≥ 95% (Sprint 08)
+      - name: Coverage gate — src/dlm/replay = 100% (Sprint 08)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/replay \
            --cov=src/dlm/replay \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/train ≥ 95% (Sprint 09)
+      - name: Coverage gate — src/dlm/train = 100% (Sprint 09)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/train \
            --cov=src/dlm/train \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/train/preference ≥ 95%
+      - name: Coverage gate — src/dlm/train/preference = 100%
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/train/preference \
            --cov=src/dlm/train/preference \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/eval ≥ 95% (Sprint 10)
+      - name: Coverage gate — src/dlm/eval = 100% (Sprint 10)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/eval \
            --cov=src/dlm/eval \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/inference ≥ 95% (Sprint 10)
+      - name: Coverage gate — src/dlm/inference = 100% (Sprint 10)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/inference \
            --cov=src/dlm/inference \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/export ≥ 95% (Sprint 11)
+      - name: Coverage gate — src/dlm/export = 100% (Sprint 11)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/export \
            --cov=src/dlm/export \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/export/ollama ≥ 95% (Sprint 12)
+      - name: Coverage gate — src/dlm/export/ollama = 100% (Sprint 12)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/export/ollama \
            --cov=src/dlm/export/ollama \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/cli/reporter ≥ 95% (Sprint 13)
+      - name: Coverage gate — src/dlm/cli/reporter = 100% (Sprint 13)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/cli \
            --cov=dlm.cli.reporter \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/io/ulid ≥ 95% (Sprint 13)
+      - name: Coverage gate — src/dlm/io/ulid = 100% (Sprint 13)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/test_io_ulid.py \
            --cov=dlm.io.ulid \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/pack ≥ 95% (Sprint 14)
+      - name: Coverage gate — src/dlm/pack = 100% (Sprint 14)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/pack tests/integration/pack \
            --cov=src/dlm/pack \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

-      - name: Coverage gate — src/dlm/lock ≥ 95% (Sprint 15)
+      - name: Coverage gate — src/dlm/lock = 100% (Sprint 15)
        if: matrix.os == 'ubuntu-latest'
        run: |
          uv run pytest tests/unit/lock \
            --cov=src/dlm/lock \
            --cov-report=term-missing \
-            --cov-fail-under=95
+            --cov-fail-under=100

   no-network-sandbox:
     # audit F13: dlm init / doctor / show must work with zero outbound network.
scripts/coverage-gates.sh (modified, 8 lines changed)

@@ -93,7 +93,7 @@ for gate in "${gates[@]}"; do
        "${tests_arr[@]}" \
        --cov="$cov" \
        --cov-report=term-missing \
-        --cov-fail-under=95 \
+        --cov-fail-under=100 \
        -q
    echo
 done
src/dlm/eval/probes.py (modified, 7 lines changed)

@@ -120,7 +120,6 @@ def _normalize_probe_markers(body: str) -> str:
             # Find the first non-blank body line and prefix it.
             i += 1
             while i < len(lines) and lines[i].strip() == "":
-                rewritten.append(lines[i])
                 i += 1
             if i < len(lines):
                 rewritten.append(f"{_PROBE_MARKER}:{lines[i]}")
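The hunk's effect is that blank lines between a probe header and its body are now dropped rather than copied through, so the marker lands on the first non-blank body line with no gap. A minimal sketch of that behavior (the `## probe` header predicate and the `_PROBE_MARKER` value are hypothetical; only the blank-skipping loop and the marker-prefix line appear in the diff):

```python
_PROBE_MARKER = "PROBE"  # assumed value; the real constant is not shown in the diff


def normalize_probe_markers(lines):
    """Prefix the first non-blank line after a probe header with the marker."""
    rewritten = []
    i = 0
    while i < len(lines):
        if lines[i].startswith("## probe"):  # hypothetical header predicate
            rewritten.append(lines[i])
            i += 1
            # Blanks are skipped without being re-appended (the fixed branch).
            while i < len(lines) and lines[i].strip() == "":
                i += 1
            if i < len(lines):
                rewritten.append(f"{_PROBE_MARKER}:{lines[i]}")
                i += 1
        else:
            rewritten.append(lines[i])
            i += 1
    return rewritten
```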
src/dlm/store/lock.py (modified, 17 lines changed)

@@ -179,9 +179,14 @@ def exclusive(

         existing = _read_lock(lock_path)
         if existing is None:
-            # Malformed lockfile, or a race between stat and read.
-            # Treat as stale to avoid infinite contention.
-            raise StaleLockError(lock_path, holder_pid=None)
+            # Malformed lockfile, or a race between create and payload
+            # write. If we still have timeout budget left, treat this as
+            # transient and retry; callers with timeout=0 still get the
+            # old immediate stale-lock signal.
+            if deadline is None or time.monotonic() >= deadline:
+                raise StaleLockError(lock_path, holder_pid=None)
+            time.sleep(poll_interval)
+            continue

         if not _is_alive(existing.pid):
             raise StaleLockError(lock_path, holder_pid=existing.pid)
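The retry semantics introduced above can be illustrated with a self-contained sketch (the helper name `read_with_retry` and its signature are hypothetical; only `StaleLockError`, `deadline`, and `poll_interval` come from the diff). With no deadline the unreadable lockfile is reported stale immediately, mirroring the `timeout=0` path; with budget left, empty reads are retried:

```python
import time


class StaleLockError(RuntimeError):
    """Stand-in for the store's stale-lock error."""


def read_with_retry(read_lock, deadline, poll_interval=0.01):
    # deadline=None models timeout=0: an unreadable lockfile raises
    # immediately. With a future deadline, empty reads are treated as
    # transient and polled until the budget runs out.
    while True:
        existing = read_lock()
        if existing is not None:
            return existing
        if deadline is None or time.monotonic() >= deadline:
            raise StaleLockError("lockfile unreadable")
        time.sleep(poll_interval)
```

This shape also explains why the original behavior caused flaky contention: a writer that had created the lockfile but not yet flushed its payload was indistinguishable from a stale lock.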
src/dlm/synth/run.py (modified, 8 lines changed)

@@ -119,8 +119,6 @@ def build_synth_plan(
         for concrete_strategy, count in _strategy_counts(strategy, per_section):
             if count == 0:
                 continue
-            if max_pairs is not None and len(additions) >= max_pairs:
-                return SynthRunPlan(additions=tuple(additions), skipped=tuple(skipped))

             template = get_prompt_template(concrete_strategy)
             rendered = teacher.generate(
src/dlm/train/cpt/schedule.py (modified, 8 lines changed)

@@ -57,8 +57,6 @@ def cosine_with_floor_lr(
         raise ValueError(f"step must be non-negative, got {step}")

     if step < warmup_steps:
-        if warmup_steps == 0:
-            return 1.0
         return step / warmup_steps

     if step >= total_steps:
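The deleted guard was dead code: `step` is validated non-negative, so `step < warmup_steps` can never be true when `warmup_steps == 0`. A self-contained sketch of the schedule shape (the floor value and the cosine-decay branch below the hunk are assumptions; only the validation and warmup lines appear in the diff):

```python
import math


def cosine_with_floor_lr(step, *, warmup_steps, total_steps, floor=0.1):
    """Linear warmup, then cosine decay toward a floor multiplier."""
    if step < 0:
        raise ValueError(f"step must be non-negative, got {step}")
    if step < warmup_steps:
        # Unreachable when warmup_steps == 0 (step is non-negative),
        # which is why the dead `warmup_steps == 0` guard was removed.
        return step / warmup_steps
    if step >= total_steps:
        return floor
    # Assumed decay form: cosine from 1.0 down to the floor.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With `warmup_steps=0` the first returned multiplier is the cosine branch at `progress=0`, i.e. exactly 1.0, so removing the special case does not change behavior.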
tests/integration/share/test_peer_roundtrip.py (modified, 32 lines changed)

@@ -21,8 +21,12 @@ from pathlib import Path

 import pytest

+from dlm.share import ServeHandle

-def _start_server_in_thread(tmp_path: Path, *, ttl: int = 600):
+
+def _start_server_in_thread(
+    tmp_path: Path, *, ttl: int = 600
+) -> tuple[ServeHandle, threading.Thread, bytes]:
     """Helper: pack a trivial file + start the peer server in a thread.

     Returns `(handle, thread, pack_bytes)`. Caller stops via
@@ -36,7 +40,10 @@ def _start_server_in_thread(tmp_path: Path, *, ttl: int = 600):
     pack.write_bytes(pack_bytes)

     opts = ServeOptions(port=0, token_ttl_seconds=ttl)  # port=0 → OS picks free port
-    handle = serve("01HZTESTID", pack, opts)
+    try:
+        handle = serve("01HZTESTID", pack, opts)
+    except PermissionError as exc:
+        pytest.skip(f"loopback bind blocked on this host: {exc}")

     thread = threading.Thread(target=handle._server.serve_forever, daemon=True)
     thread.start()
@@ -46,7 +53,7 @@ def _start_server_in_thread(tmp_path: Path, *, ttl: int = 600):
     return handle, thread, pack_bytes


-def _stop_server(handle, thread: threading.Thread) -> None:
+def _stop_server(handle: ServeHandle, thread: threading.Thread) -> None:
     handle._server.shutdown()
     handle._server.server_close()
     thread.join(timeout=2.0)
tests/unit/base_models/test_probes.py (modified, 28 lines changed)

@@ -152,6 +152,12 @@ class TestProbeChatTemplate:
         ):
             probe_chat_template(_spec())

+    def test_load_error_returns_failed_probe(self) -> None:
+        with patch("transformers.AutoTokenizer.from_pretrained", side_effect=RuntimeError("boom")):
+            result = probe_chat_template(_spec())
+        assert result.passed is False
+        assert "load failed: RuntimeError: boom" in result.detail
+

 class TestProbeGgufArch:
     def test_skips_when_vendor_missing(self, tmp_path: Path) -> None:
@@ -478,6 +484,16 @@ class TestProbeAudioToken:
         ):
             probe_audio_token(_audio_spec())

+    def test_processor_load_generic_error_fails(self) -> None:
+        with patch(
+            "dlm.base_models._typed_shims.load_auto_processor",
+            side_effect=RuntimeError("connection reset"),
+        ):
+            result = probe_audio_token(_audio_spec())
+        assert result.passed is False
+        assert "processor load failed" in result.detail
+        assert "RuntimeError" in result.detail
+
     def test_missing_tokenizer_fails(self) -> None:
         with patch(
             "dlm.base_models._typed_shims.load_auto_processor",
tests/unit/cli/test_app_core.py (added, 91 lines changed)

@@ -0,0 +1,91 @@
+"""Direct coverage for top-level CLI app wiring."""
+
+from __future__ import annotations
+
+import logging
+import runpy
+from unittest.mock import patch
+
+import pytest
+import typer
+
+import dlm.cli.app as cli_app
+
+
+def test_disable_third_party_telemetry_sets_defaults(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.delenv("HF_HUB_DISABLE_TELEMETRY", raising=False)
+    monkeypatch.delenv("DO_NOT_TRACK", raising=False)
+    monkeypatch.delenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", raising=False)
+
+    cli_app._disable_third_party_telemetry()
+
+    assert cli_app.os.environ["HF_HUB_DISABLE_TELEMETRY"] == "1"
+    assert cli_app.os.environ["DO_NOT_TRACK"] == "1"
+    assert cli_app.os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] == "1"
+
+
+def test_version_callback_exits_when_requested(capsys: pytest.CaptureFixture[str]) -> None:
+    with pytest.raises(typer.Exit) as excinfo:
+        cli_app._version_callback(True)
+
+    assert excinfo.value.exit_code == 0
+    assert "dlm " in capsys.readouterr().out
+
+
+def test_version_callback_is_noop_when_flag_is_false(capsys: pytest.CaptureFixture[str]) -> None:
+    cli_app._version_callback(False)
+    assert capsys.readouterr().out == ""
+
+
+def test_root_sets_home_and_debug_logging(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.delenv("DLM_HOME", raising=False)
+
+    with patch("logging.basicConfig") as basic_config:
+        cli_app._root(version=False, home="/tmp/dlm-home", verbose=True, quiet=False)
+
+    assert cli_app.os.environ["DLM_HOME"] == "/tmp/dlm-home"
+    basic_config.assert_called_once_with(
+        level=logging.DEBUG,
+        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+        force=True,
+    )
+
+
+def test_root_sets_warning_logging_for_quiet_mode() -> None:
+    with patch("logging.basicConfig") as basic_config:
+        cli_app._root(version=False, home=None, verbose=False, quiet=True)
+
+    basic_config.assert_called_once_with(
+        level=logging.WARNING,
+        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+        force=True,
+    )
+
+
+def test_root_rejects_verbose_and_quiet_together() -> None:
+    with pytest.raises(typer.BadParameter, match="mutually exclusive"):
+        cli_app._root(version=False, home=None, verbose=True, quiet=True)
+
+
+def test_main_routes_through_reporter_and_exits() -> None:
+    with (
+        patch("dlm.cli.reporter.run_with_reporter", return_value=7) as run_with_reporter,
+        patch("sys.exit", side_effect=SystemExit(7)) as sys_exit,
+        pytest.raises(SystemExit) as excinfo,
+    ):
+        cli_app.main()
+
+    assert excinfo.value.code == 7
+    run_with_reporter.assert_called_once_with(cli_app.app)
+    sys_exit.assert_called_once_with(7)
+
+
+def test_module_main_guard_invokes_main() -> None:
+    with (
+        patch("dlm.cli.reporter.run_with_reporter", return_value=3),
+        patch("sys.exit", side_effect=SystemExit(3)),
+        pytest.raises(SystemExit) as excinfo,
+    ):
+        runpy.run_module("dlm.cli.app", run_name="__main__")
+
+    assert excinfo.value.code == 3
tests/unit/cli/test_export_edge_paths.py (added, 258 lines changed)

@@ -0,0 +1,258 @@
+"""Focused early-branch coverage for `dlm export`."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+from typer.testing import CliRunner
+
+from dlm.base_models import BaseModelSpec
+from dlm.base_models.errors import GatedModelError
+from dlm.cli.app import app
+from dlm.export.errors import ExportError
+
+
+def _joined_output(result: object) -> str:
+    text = getattr(result, "output", "") + getattr(result, "stderr", "")
+    return " ".join(text.split())
+
+
+def _scaffold_doc(tmp_path: Path) -> Path:
+    doc = tmp_path / "doc.dlm"
+    runner = CliRunner()
+    result = runner.invoke(
+        app,
+        [
+            "--home",
+            str(tmp_path / "home"),
+            "init",
+            str(doc),
+            "--base",
+            "smollm2-135m",
+        ],
+    )
+    assert result.exit_code == 0, result.output
+    return doc
+
+
+def _spec(*, key: str = "demo-1b", modality: str = "text") -> BaseModelSpec:
+    payload: dict[str, object] = {
+        "key": key,
+        "hf_id": f"org/{key}",
+        "revision": "0123456789abcdef0123456789abcdef01234567",
+        "architecture": "DemoForCausalLM",
+        "params": 1_000_000_000,
+        "target_modules": ["q_proj", "v_proj"],
+        "template": "chatml",
+        "gguf_arch": "demo",
+        "tokenizer_pre": "demo",
+        "license_spdx": "Apache-2.0",
+        "license_url": None,
+        "requires_acceptance": False,
+        "redistributable": True,
+        "size_gb_fp16": 2.0,
+        "context_length": 4096,
+        "recommended_seq_len": 2048,
+        "modality": modality,
+    }
+    if modality == "vision-language":
+        payload["vl_preprocessor_plan"] = {
+            "target_size": [224, 224],
+            "image_token": "<image>",
+            "num_image_tokens": 256,
+        }
+    elif modality == "audio-language":
+        payload["audio_preprocessor_plan"] = {
+            "sample_rate": 16000,
+            "audio_token": "<audio>",
+            "num_audio_tokens": 64,
+            "max_length_seconds": 30.0,
+        }
+    return BaseModelSpec.model_validate(payload)
+
+
+def _patch_export_runtime(
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    spec: BaseModelSpec | None = None,
+    dispatch: object | None = None,
+) -> None:
+    monkeypatch.setattr(
+        "dlm.base_models.resolve",
+        lambda *args, **kwargs: spec or _spec(),
+    )
+    monkeypatch.setattr(
+        "dlm.modality.modality_for",
+        lambda model_spec: (
+            dispatch
+            or SimpleNamespace(
+                accepts_images=model_spec.modality == "vision-language",
+                accepts_audio=model_spec.modality == "audio-language",
+            )
+        ),
+    )
+
+
+class TestExportEdgePaths:
+    def test_gate_fallback_banner_prints_before_gated_base_refusal(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        monkeypatch.setattr(
+            "dlm.export.gate_fallback.resolve_and_announce",
+            lambda store, parsed: SimpleNamespace(
+                entries=[("knowledge", 0.7), ("tone", 0.3)],
+                banner_lines=["[yellow]gate:[/yellow] using learned adapter prior"],
+            ),
+        )
+        monkeypatch.setattr(
+            "dlm.base_models.resolve",
+            lambda *args, **kwargs: (_ for _ in ()).throw(
+                GatedModelError("org/gated-base", "https://example.test/license")
+            ),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc)],
+        )
+
+        assert result.exit_code == 1, result.output
+        text = _joined_output(result)
+        assert "using learned adapter prior" in text
+        assert "review the license at: https://example.test/license" in text
+        assert "accept via `dlm train --i-accept-license` before exporting." in text
+
+    @pytest.mark.parametrize(
+        ("target", "modality", "needle"),
+        [
+            (
+                "vllm",
+                "audio-language",
+                "--target vllm is not wired for audio-language documents yet",
+            ),
+            (
+                "mlx-serve",
+                "audio-language",
+                "--target mlx-serve is not wired for audio-language documents yet",
+            ),
+            (
+                "vllm",
+                "vision-language",
+                "--target vllm is not wired for vision-language documents yet",
+            ),
+            (
+                "mlx-serve",
+                "vision-language",
+                "--target mlx-serve is not wired for vision-language documents yet",
+            ),
+        ],
+    )
+    def test_runtime_targets_refuse_unsupported_modalities(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        target: str,
+        modality: str,
+        needle: str,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        _patch_export_runtime(
+            monkeypatch, spec=_spec(key=f"{target}-{modality}", modality=modality)
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--target", target],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert needle in _joined_output(result)
+
+    def test_audio_dispatch_export_error_maps_to_exit_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        class _AudioDispatch:
+            accepts_images = False
+            accepts_audio = True
+
+            def dispatch_export(self, **kwargs: object) -> object:
+                raise ExportError("audio snapshot failed")
+
+        _patch_export_runtime(
+            monkeypatch,
+            spec=_spec(key="audio-demo", modality="audio-language"),
+            dispatch=_AudioDispatch(),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc)],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "audio snapshot failed" in _joined_output(result)
+
+    def test_vl_dispatch_export_error_maps_to_exit_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        class _VlDispatch:
+            accepts_images = True
+            accepts_audio = False
+
+            def dispatch_export(self, **kwargs: object) -> object:
+                raise ExportError("vl snapshot failed")
+
+        _patch_export_runtime(
+            monkeypatch,
+            spec=_spec(key="vl-demo", modality="vision-language"),
+            dispatch=_VlDispatch(),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc)],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "vl snapshot failed" in _joined_output(result)
+
+    def test_invalid_export_plan_value_exits_2(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        _patch_export_runtime(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.export.resolve_export_plan",
+            lambda **kwargs: (_ for _ in ()).throw(ValueError("bad export plan")),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc)],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "bad export plan" in _joined_output(result)
tests/unit/cli/test_export_run_errors.py (added, 191 lines)
@@ -0,0 +1,191 @@
+"""CLI coverage for generic `run_export(...)` branches."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from typer.testing import CliRunner
+
+from dlm.base_models import BaseModelSpec
+from dlm.cli.app import app
+from dlm.export.errors import ExportError, PreflightError, SubprocessError, UnsafeMergeError
+from dlm.export.ollama.errors import (
+    OllamaCreateError,
+    OllamaError,
+    OllamaSmokeError,
+    OllamaVersionError,
+)
+
+
+def _joined_output(result: object) -> str:
+    text = getattr(result, "output", "") + getattr(result, "stderr", "")
+    return " ".join(text.split())
+
+
+def _scaffold_doc(tmp_path: Path) -> Path:
+    doc = tmp_path / "doc.dlm"
+    runner = CliRunner()
+    result = runner.invoke(
+        app,
+        [
+            "--home",
+            str(tmp_path / "home"),
+            "init",
+            str(doc),
+            "--base",
+            "smollm2-135m",
+        ],
+    )
+    assert result.exit_code == 0, result.output
+    return doc
+
+
+def _spec() -> BaseModelSpec:
+    return BaseModelSpec.model_validate(
+        {
+            "key": "demo-1b",
+            "hf_id": "org/demo-1b",
+            "revision": "0123456789abcdef0123456789abcdef01234567",
+            "architecture": "DemoForCausalLM",
+            "params": 1_000_000_000,
+            "target_modules": ["q_proj", "v_proj"],
+            "template": "chatml",
+            "gguf_arch": "demo",
+            "tokenizer_pre": "demo",
+            "license_spdx": "Apache-2.0",
+            "license_url": None,
+            "requires_acceptance": False,
+            "redistributable": True,
+            "size_gb_fp16": 2.0,
+            "context_length": 4096,
+            "recommended_seq_len": 2048,
+        }
+    )
+
+
+def _patch_export_runtime(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("dlm.base_models.resolve", lambda *args, **kwargs: _spec())
+    monkeypatch.setattr(
+        "dlm.base_models.download_spec",
+        lambda *args, **kwargs: SimpleNamespace(path=Path("/tmp/base-cache")),
+    )
+    monkeypatch.setattr(
+        "dlm.modality.modality_for",
+        lambda spec: SimpleNamespace(accepts_images=False, accepts_audio=False),
+    )
+    monkeypatch.setattr(
+        "dlm.export.gate_fallback.resolve_and_announce",
+        lambda store, parsed: SimpleNamespace(entries=None, banner_lines=[]),
+    )
+    monkeypatch.setattr(
+        "dlm.export.targets.resolve_target",
+        lambda name: SimpleNamespace(name="ollama"),
+    )
+
+
+class TestExportRunErrors:
+    def test_verbose_success_prints_shell_command_and_cached_tag(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+        captured: dict[str, Any] = {}
+
+        _patch_export_runtime(monkeypatch)
+
+        def _run_export(
+            store: object,
+            spec: object,
+            plan: object,
+            **kwargs: object,
+        ) -> object:
+            captured.update(kwargs)
+            subprocess_runner = kwargs["subprocess_runner"]
+            assert callable(subprocess_runner)
+            subprocess_runner(["llama-quantize", "--version"])
+            return SimpleNamespace(
+                cached=True,
+                export_dir=tmp_path / "exports" / "Q4_K_M",
+                artifacts=[SimpleNamespace(name="base.gguf"), SimpleNamespace(name="adapter.gguf")],
+                target="ollama",
+                ollama_name="demo-model",
+                ollama_version=1,
+                smoke_output_first_line="hello smoke",
+            )
+
+        monkeypatch.setattr("dlm.export.run_export", _run_export)
+        monkeypatch.setattr(
+            "dlm.export.quantize.run_checked", lambda cmd: SimpleNamespace(returncode=0)
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--verbose"],
+        )
+
+        assert result.exit_code == 0, result.output
+        text = _joined_output(result)
+        assert "$ llama-quantize --version" in text
+        assert "(cached base)" in text
+        assert "ollama: demo-model (v1)" in text
+        assert "smoke: hello smoke" in text
+        assert captured["cached_base_dir"] == Path("/tmp/base-cache")
+        assert captured["target"] == "ollama"
+
+    @pytest.mark.parametrize(
+        ("error", "needle"),
+        [
+            (UnsafeMergeError("needs --dequantize"), "merge:"),
+            (
+                PreflightError(probe="template", detail="template mismatch"),
+                "preflight: template mismatch",
+            ),
+            (
+                SubprocessError(
+                    cmd=["llama-quantize"],
+                    returncode=3,
+                    stderr_tail="quantize failed",
+                ),
+                "subprocess:",
+            ),
+            (
+                OllamaVersionError(detected=(0, 1, 0), required=(0, 6, 0)),
+                "ollama:",
+            ),
+            (OllamaCreateError(stdout="", stderr="create failed"), "ollama create:"),
+            (OllamaSmokeError(stdout="", stderr="smoke failed"), "smoke:"),
+            (OllamaError("generic ollama error"), "ollama:"),
+            (ExportError("plain export failure"), "export:"),
+        ],
+    )
+    def test_run_export_error_mappings_exit_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        error: Exception,
+        needle: str,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        _patch_export_runtime(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.export.run_export",
+            lambda *args, **kwargs: (_ for _ in ()).throw(error),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc)],
+        )
+
+        assert result.exit_code == 1, result.output
+        text = _joined_output(result)
+        assert needle in text
+        if isinstance(error, OllamaSmokeError):
+            assert "re-run with `--no-smoke`" in text
tests/unit/cli/test_export_target_runtime_paths.py (added, 352 lines)
@@ -0,0 +1,352 @@
+"""CLI coverage for vLLM / MLX runtime-target success and smoke paths."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from typer.testing import CliRunner
+
+from dlm.base_models import BaseModelSpec
+from dlm.cli.app import app
+from dlm.export.errors import ExportError
+
+
+def _joined_output(result: object) -> str:
+    text = getattr(result, "output", "") + getattr(result, "stderr", "")
+    return " ".join(text.split())
+
+
+def _scaffold_doc(tmp_path: Path) -> Path:
+    doc = tmp_path / "doc.dlm"
+    runner = CliRunner()
+    result = runner.invoke(
+        app,
+        [
+            "--home",
+            str(tmp_path / "home"),
+            "init",
+            str(doc),
+            "--base",
+            "smollm2-135m",
+        ],
+    )
+    assert result.exit_code == 0, result.output
+    return doc
+
+
+def _spec() -> BaseModelSpec:
+    return BaseModelSpec.model_validate(
+        {
+            "key": "demo-1b",
+            "hf_id": "org/demo-1b",
+            "revision": "0123456789abcdef0123456789abcdef01234567",
+            "architecture": "DemoForCausalLM",
+            "params": 1_000_000_000,
+            "target_modules": ["q_proj", "v_proj"],
+            "template": "chatml",
+            "gguf_arch": "demo",
+            "tokenizer_pre": "demo",
+            "license_spdx": "Apache-2.0",
+            "license_url": None,
+            "requires_acceptance": False,
+            "redistributable": True,
+            "size_gb_fp16": 2.0,
+            "context_length": 4096,
+            "recommended_seq_len": 2048,
+        }
+    )
+
+
+def _patch_text_export_runtime(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("dlm.base_models.resolve", lambda *args, **kwargs: _spec())
+    monkeypatch.setattr(
+        "dlm.base_models.download_spec",
+        lambda *args, **kwargs: SimpleNamespace(path=Path("/tmp/base-cache")),
+    )
+    monkeypatch.setattr(
+        "dlm.modality.modality_for",
+        lambda spec: SimpleNamespace(accepts_images=False, accepts_audio=False),
+    )
+    monkeypatch.setattr(
+        "dlm.export.gate_fallback.resolve_and_announce",
+        lambda store, parsed: SimpleNamespace(entries=None, banner_lines=[]),
+    )
+
+
+class _FakeTarget:
+    def __init__(self, name: str, smoke_result: object | None) -> None:
+        self.name = name
+        self._smoke_result = smoke_result
+        self.calls: list[object] = []
+
+    def smoke_test(self, prepared: object) -> object | None:
+        self.calls.append(prepared)
+        return self._smoke_result
+
+
+class TestExportRuntimeTargetPaths:
+    def test_vllm_target_success_prints_launch_config_and_smoke(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+        captured: dict[str, Any] = {}
+        smoke = SimpleNamespace(ok=True, detail="vllm smoke ok")
+        fake_target = _FakeTarget("vllm", smoke)
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr("dlm.export.targets.resolve_target", lambda name: fake_target)
+
+        def _prepare(**kwargs: object) -> object:
+            captured.update(kwargs)
+            export_dir = tmp_path / "exports" / "vllm"
+            launch = export_dir / "vllm_launch.sh"
+            config = export_dir / "vllm_config.json"
+            return SimpleNamespace(
+                export_dir=export_dir,
+                launch_script_path=launch,
+                config_path=config,
+            )
+
+        monkeypatch.setattr("dlm.export.targets.prepare_vllm_export", _prepare)
+        monkeypatch.setattr(
+            "dlm.export.targets.finalize_vllm_export",
+            lambda **kwargs: tmp_path / "exports" / "vllm" / "export_manifest.json",
+        )
+
+        result = runner.invoke(
+            app,
+            [
+                "--home",
+                str(tmp_path / "home"),
+                "export",
+                str(doc),
+                "--target",
+                "vllm",
+                "--name",
+                "served-demo",
+                "--quant",
+                "Q4_K_M",
+                "--merged",
+                "--dequantize",
+                "--no-template",
+                "--skip-ollama",
+                "--no-imatrix",
+                "--draft",
+                "qwen2.5:0.5b",
+            ],
+        )
+
+        assert result.exit_code == 0, result.output
+        text = _joined_output(result)
+        assert "ignoring flags not applicable to `--target vllm`" in text
+        assert "--quant" in text
+        assert "--merged" in text
+        assert "--dequantize" in text
+        assert "--no-template" in text
+        assert "--skip-ollama" in text
+        assert "--no-imatrix" in text
+        assert "--draft" in text
+        assert "target: vllm" in text
+        assert "launch: vllm_launch.sh" in text
+        assert "config: vllm_config.json" in text
+        assert "manifest: export_manifest.json" in text
+        assert "smoke: vllm smoke ok" in text
+        assert captured["served_model_name"] == "served-demo"
+        assert captured["training_sequence_len"] == 2048
+        assert fake_target.calls
+
+    def test_vllm_target_prepare_error_exits_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.export.targets.resolve_target",
+            lambda name: _FakeTarget("vllm", None),
+        )
+        monkeypatch.setattr(
+            "dlm.export.targets.prepare_vllm_export",
+            lambda **kwargs: (_ for _ in ()).throw(ExportError("vllm prepare failed")),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--target", "vllm"],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "vllm prepare failed" in _joined_output(result)
+
+    def test_vllm_target_smoke_failure_exits_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+        fake_target = _FakeTarget("vllm", SimpleNamespace(ok=False, detail="vllm smoke failed"))
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr("dlm.export.targets.resolve_target", lambda name: fake_target)
+        monkeypatch.setattr(
+            "dlm.export.targets.prepare_vllm_export",
+            lambda **kwargs: SimpleNamespace(
+                export_dir=tmp_path / "exports" / "vllm",
+                launch_script_path=tmp_path / "exports" / "vllm" / "vllm_launch.sh",
+                config_path=tmp_path / "exports" / "vllm" / "vllm_config.json",
+            ),
+        )
+        monkeypatch.setattr(
+            "dlm.export.targets.finalize_vllm_export",
+            lambda **kwargs: tmp_path / "exports" / "vllm" / "export_manifest.json",
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--target", "vllm"],
+        )
+
+        assert result.exit_code == 1, result.output
+        text = _joined_output(result)
+        assert "vllm smoke failed" in text
+        assert "re-run with `--no-smoke`" in text
+
+    def test_mlx_target_success_prints_launch_manifest_and_smoke(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+        captured: dict[str, Any] = {}
+        smoke = SimpleNamespace(ok=True, detail="mlx smoke ok")
+        fake_target = _FakeTarget("mlx-serve", smoke)
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr("dlm.export.targets.resolve_target", lambda name: fake_target)
+
+        def _prepare(**kwargs: object) -> object:
+            captured.update(kwargs)
+            export_dir = tmp_path / "exports" / "mlx-serve"
+            launch = export_dir / "mlx_serve_launch.sh"
+            return SimpleNamespace(
+                export_dir=export_dir,
+                launch_script_path=launch,
+            )
+
+        monkeypatch.setattr("dlm.export.targets.prepare_mlx_serve_export", _prepare)
+        monkeypatch.setattr(
+            "dlm.export.targets.finalize_mlx_serve_export",
+            lambda **kwargs: tmp_path / "exports" / "mlx-serve" / "export_manifest.json",
+        )
+
+        result = runner.invoke(
+            app,
+            [
+                "--home",
+                str(tmp_path / "home"),
+                "export",
+                str(doc),
+                "--target",
+                "mlx-serve",
+                "--name",
+                "ignored-name",
+                "--quant",
+                "Q4_K_M",
+                "--merged",
+                "--dequantize",
+                "--no-template",
+                "--skip-ollama",
+                "--no-imatrix",
+                "--draft",
+                "qwen2.5:0.5b",
+            ],
+        )
+
+        assert result.exit_code == 0, result.output
+        text = _joined_output(result)
+        assert "ignoring flags not applicable to `--target mlx-serve`" in text
+        assert "--name" in text
+        assert "--quant" in text
+        assert "--merged" in text
+        assert "--dequantize" in text
+        assert "--no-template" in text
+        assert "--skip-ollama" in text
+        assert "--no-imatrix" in text
+        assert "--draft" in text
+        assert "target: mlx-serve" in text
+        assert "launch: mlx_serve_launch.sh" in text
+        assert "manifest: export_manifest.json" in text
+        assert "smoke: mlx smoke ok" in text
+        assert captured["adapter_name"] is None
+        assert captured["adapter_path_override"] is None
+        assert fake_target.calls
+
+    def test_mlx_target_prepare_error_exits_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.export.targets.resolve_target",
+            lambda name: _FakeTarget("mlx-serve", None),
+        )
+        monkeypatch.setattr(
+            "dlm.export.targets.prepare_mlx_serve_export",
+            lambda **kwargs: (_ for _ in ()).throw(ExportError("mlx prepare failed")),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--target", "mlx-serve"],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "mlx prepare failed" in _joined_output(result)
+
+    def test_mlx_target_smoke_failure_exits_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = _scaffold_doc(tmp_path)
+        runner = CliRunner()
+        fake_target = _FakeTarget("mlx-serve", SimpleNamespace(ok=False, detail="mlx smoke failed"))
+
+        _patch_text_export_runtime(monkeypatch)
+        monkeypatch.setattr("dlm.export.targets.resolve_target", lambda name: fake_target)
+        monkeypatch.setattr(
+            "dlm.export.targets.prepare_mlx_serve_export",
+            lambda **kwargs: SimpleNamespace(
+                export_dir=tmp_path / "exports" / "mlx-serve",
+                launch_script_path=tmp_path / "exports" / "mlx-serve" / "mlx_serve_launch.sh",
+            ),
+        )
+        monkeypatch.setattr(
+            "dlm.export.targets.finalize_mlx_serve_export",
+            lambda **kwargs: tmp_path / "exports" / "mlx-serve" / "export_manifest.json",
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "export", str(doc), "--target", "mlx-serve"],
+        )
+
+        assert result.exit_code == 1, result.output
+        text = _joined_output(result)
+        assert "mlx smoke failed" in text
+        assert "re-run with `--no-smoke`" in text
tests/unit/cli/test_init_edges.py (added, 157 lines)
@@ -0,0 +1,157 @@
+"""Edge coverage for `dlm init` helper paths near the top of cli/commands."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+from rich.console import Console
+from typer.testing import CliRunner
+
+import dlm.base_models as base_models
+import dlm.templates as templates
+from dlm.base_models.errors import GatedModelError
+from dlm.cli import commands
+from dlm.cli.app import app
+from dlm.templates.errors import TemplateError
+
+
+def test_stub_mentions_sprint_and_subject() -> None:
+    with pytest.raises(NotImplementedError, match="owned by Sprint 43"):
+        commands._stub("43", "dlm synth")
+
+
+class TestPromptAcceptLicense:
+    def test_non_tty_returns_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        console = Console(record=True)
+        monkeypatch.setattr("sys.stdin.isatty", lambda: False)
+
+        assert commands._prompt_accept_license(console, "llama-3.2-1b", None) is False
+
+    def test_yes_accepts_and_prints_license_url(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        console = Console(record=True)
+        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
+        monkeypatch.setattr("builtins.input", lambda: "Yes")
+
+        assert (
+            commands._prompt_accept_license(
+                console,
+                "llama-3.2-1b",
+                "https://example.test/license",
+            )
+            is True
+        )
+        text = console.export_text()
+        assert "requires accepting the upstream license" in text
+        assert "https://example.test/license" in text
+
+    def test_eof_returns_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        console = Console(record=True)
+        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
+
+        def _raise_eof() -> str:
+            raise EOFError
+
+        monkeypatch.setattr("builtins.input", _raise_eof)
+
+        assert commands._prompt_accept_license(console, "llama-3.2-1b", None) is False
+
+
+class TestInitTemplateEdges:
+    def test_explicit_base_warning_when_template_overrides(self, tmp_path: Path) -> None:
+        runner = CliRunner()
+        out = tmp_path / "doc.dlm"
+        home = tmp_path / "home"
+
+        result = runner.invoke(
+            app,
+            [
+                "--home",
+                str(home),
+                "init",
+                str(out),
+                "--base",
+                "smollm2-135m",
+                "--template",
+                "changelog",
+            ],
+        )
+
+        assert result.exit_code == 0, result.output
+        joined = " ".join((result.output + result.stderr).split())
+        assert "--base smollm2-135m ignored" in joined
+        assert "uses smollm2-360m" in joined
+
+    def test_interactive_acceptance_retries_resolution(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        runner = CliRunner()
+        out = tmp_path / "doc.dlm"
+        home = tmp_path / "home"
+        calls: list[tuple[str, bool, bool]] = []
+        spec = SimpleNamespace(key="llama-3.2-1b", revision="rev-1", modality="text")
+
+        def _fake_resolve(
+            base: str,
+            *,
+            accept_license: bool = False,
+            skip_export_probes: bool = False,
+        ) -> object:
+            calls.append((base, accept_license, skip_export_probes))
+            if len(calls) == 1:
+                raise GatedModelError(base, "https://example.test/license")
+            return spec
+
+        monkeypatch.setattr(base_models, "resolve", _fake_resolve)
+        monkeypatch.setattr(base_models, "is_gated", lambda spec: False)
+        monkeypatch.setattr(commands, "_prompt_accept_license", lambda console, base, url: True)
+
+        result = runner.invoke(
+            app,
+            ["--home", str(home), "init", str(out), "--base", "llama-3.2-1b"],
+        )
+
+        assert result.exit_code == 0, result.output
+        assert calls == [
+            ("llama-3.2-1b", False, False),
+            ("llama-3.2-1b", True, False),
+        ]
+        assert out.exists()
+
+    def test_template_apply_error_exits_cleanly(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        runner = CliRunner()
+        out = tmp_path / "doc.dlm"
+        home = tmp_path / "home"
+
+        monkeypatch.setattr(
+            templates,
+            "load_template",
+            lambda name: SimpleNamespace(meta=SimpleNamespace(recommended_base="smollm2-135m")),
+        )
+
+        def _fake_apply_template(
+            name: str,
+            target: Path,
+            *,
+            force: bool = False,
+            accept_license: bool = False,
+        ) -> object:
+            raise TemplateError("template exploded")
+
+        monkeypatch.setattr(templates, "apply_template", _fake_apply_template)
+
+        result = runner.invoke(
+            app,
+            ["--home", str(home), "init", str(out), "--template", "custom"],
+        )
+
+        assert result.exit_code == 1
+        assert "template exploded" in result.output
+        assert not out.exists()
tests/unit/cli/test_prompt_edges.py (added, 291 lines)
@@ -0,0 +1,291 @@
+"""Focused `dlm prompt` edge coverage for the remaining text/VL/audio branches."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from typer.testing import CliRunner
+
+from dlm.base_models import BaseModelSpec
+from dlm.cli.app import app
+
+
+def _write_doc(path: Path, *, base_model: str = "demo-1b") -> None:
+    path.write_text(
+        f"---\ndlm_id: 01HZ4X7TGZM3J1A2B3C4D5E6F7\nbase_model: {base_model}\n---\nbody\n",
+        encoding="utf-8",
+    )
+
+
+def _joined_output(result: object) -> str:
+    text = getattr(result, "output", "") + getattr(result, "stderr", "")
+    return " ".join(text.split())
+
+
+def _spec(*, key: str = "demo-1b", modality: str = "text") -> BaseModelSpec:
+    payload: dict[str, object] = {
+        "key": key,
+        "hf_id": f"org/{key}",
+        "revision": "0123456789abcdef0123456789abcdef01234567",
+        "architecture": "DemoForCausalLM",
+        "params": 1_000_000_000,
+        "target_modules": ["q_proj", "v_proj"],
+        "template": "chatml",
+        "gguf_arch": "demo",
+        "tokenizer_pre": "demo",
+        "license_spdx": "Apache-2.0",
+        "license_url": None,
+        "requires_acceptance": False,
+        "redistributable": True,
+        "size_gb_fp16": 2.0,
+        "context_length": 4096,
+        "recommended_seq_len": 2048,
+        "modality": modality,
+    }
+    if modality == "vision-language":
+        payload["vl_preprocessor_plan"] = {
+            "target_size": [224, 224],
+            "image_token": "<image>",
+            "num_image_tokens": 256,
+        }
+    elif modality == "audio-language":
+        payload["audio_preprocessor_plan"] = {
+            "sample_rate": 16000,
+            "audio_token": "<audio>",
+            "num_audio_tokens": 64,
+            "max_length_seconds": 30.0,
+        }
+    return BaseModelSpec.model_validate(payload)
+
+
+def _patch_prompt_runtime(
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    spec: BaseModelSpec | None = None,
+    dispatch: object | None = None,
+) -> None:
+    monkeypatch.setattr(
+        "dlm.base_models.resolve",
+        lambda *args, **kwargs: spec or _spec(),
+    )
+    monkeypatch.setattr(
+        "dlm.hardware.doctor",
+        lambda: SimpleNamespace(capabilities=object()),
+    )
+    monkeypatch.setattr(
+        "dlm.modality.modality_for",
+        lambda model_spec: (
+            dispatch
+            or SimpleNamespace(
+                accepts_images=model_spec.modality == "vision-language",
+                accepts_audio=model_spec.modality == "audio-language",
+            )
+        ),
+    )
+
+
+class TestPromptEdgeBranches:
+    def test_invalid_backend_value_exits_2(self, tmp_path: Path) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_doc(doc)
+        runner = CliRunner()
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--backend", "bogus"],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "--backend must be" in _joined_output(result)
+
+    def test_gated_base_without_recorded_acceptance_exits_1(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        from dlm.base_models.errors import GatedModelError
+
+        doc = tmp_path / "doc.dlm"
+        _write_doc(doc, base_model="gated-base")
+        runner = CliRunner()
+
+        monkeypatch.setattr(
+            "dlm.base_models.resolve",
+            lambda *args, **kwargs: (_ for _ in ()).throw(
+                GatedModelError("org/gated-base", "https://license.example")
+            ),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello"],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "run `dlm train --i-accept-license` first" in _joined_output(result)
+
+    def test_unsupported_backend_error_exits_2(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        from dlm.inference.backends.select import UnsupportedBackendError
+
+        doc = tmp_path / "doc.dlm"
+        _write_doc(doc)
+        runner = CliRunner()
+
+        _patch_prompt_runtime(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.inference.backends.select_backend",
+            lambda *args, **kwargs: (_ for _ in ()).throw(
+                UnsupportedBackendError("mlx backend unavailable")
+            ),
+        )
+
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--backend", "mlx"],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "mlx backend unavailable" in _joined_output(result)
+
+    def test_verbose_text_prompt_logs_backend_and_generates(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_doc(doc)
+        runner = CliRunner()
+        captured: dict[str, Any] = {}
+
+        class _FakeBackend:
+            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
+                captured["adapter_name"] = adapter_name
170
+
171
+            def generate(self, query: str, **kwargs: object) -> str:
172
+                captured["query"] = query
173
+                captured["kwargs"] = kwargs
174
+                return "ok"
175
+
176
+        _patch_prompt_runtime(monkeypatch)
177
+        monkeypatch.setattr(
178
+            "dlm.inference.backends.select_backend",
179
+            lambda *args, **kwargs: "pytorch",
180
+        )
181
+        monkeypatch.setattr(
182
+            "dlm.inference.backends.build_backend",
183
+            lambda *args, **kwargs: _FakeBackend(),
184
+        )
185
+
186
+        result = runner.invoke(
187
+            app,
188
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--verbose"],
189
+        )
190
+
191
+        assert result.exit_code == 0, result.output
192
+        assert captured["query"] == "hello"
193
+        assert "backend: pytorch" in _joined_output(result)
194
+        kwargs = captured["kwargs"]
195
+        assert isinstance(kwargs, dict)
196
+        assert kwargs["top_p"] is None
197
+
198
+    def test_missing_adapter_maps_to_exit_1(
199
+        self,
200
+        tmp_path: Path,
201
+        monkeypatch: pytest.MonkeyPatch,
202
+    ) -> None:
203
+        from dlm.inference import AdapterNotFoundError
204
+
205
+        doc = tmp_path / "doc.dlm"
206
+        _write_doc(doc)
207
+        runner = CliRunner()
208
+
209
+        class _MissingAdapterBackend:
210
+            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
211
+                raise AdapterNotFoundError("missing adapter")
212
+
213
+        _patch_prompt_runtime(monkeypatch)
214
+        monkeypatch.setattr(
215
+            "dlm.inference.backends.select_backend",
216
+            lambda *args, **kwargs: "pytorch",
217
+        )
218
+        monkeypatch.setattr(
219
+            "dlm.inference.backends.build_backend",
220
+            lambda *args, **kwargs: _MissingAdapterBackend(),
221
+        )
222
+
223
+        result = runner.invoke(
224
+            app,
225
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello"],
226
+        )
227
+
228
+        assert result.exit_code == 1, result.output
229
+        assert "missing adapter" in _joined_output(result)
230
+
231
+    def test_vision_language_dispatch_branch_invokes_helper(
232
+        self,
233
+        tmp_path: Path,
234
+        monkeypatch: pytest.MonkeyPatch,
235
+    ) -> None:
236
+        doc = tmp_path / "doc.dlm"
237
+        _write_doc(doc, base_model="vl-demo")
238
+        image = tmp_path / "frame.png"
239
+        image.write_bytes(b"\x89PNG fake")
240
+        runner = CliRunner()
241
+        captured: dict[str, Any] = {}
242
+
243
+        _patch_prompt_runtime(
244
+            monkeypatch,
245
+            spec=_spec(key="vl-demo", modality="vision-language"),
246
+        )
247
+        monkeypatch.setattr(
248
+            "dlm.cli.commands._dispatch_vl_prompt",
249
+            lambda **kwargs: captured.update(kwargs),
250
+        )
251
+
252
+        result = runner.invoke(
253
+            app,
254
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--image", str(image)],
255
+        )
256
+
257
+        assert result.exit_code == 0, result.output
258
+        assert captured["query"] == "hello"
259
+        assert captured["image_paths"] == [image]
260
+        assert captured["spec"].key == "vl-demo"
261
+
262
+    def test_audio_dispatch_branch_invokes_helper(
263
+        self,
264
+        tmp_path: Path,
265
+        monkeypatch: pytest.MonkeyPatch,
266
+    ) -> None:
267
+        doc = tmp_path / "doc.dlm"
268
+        _write_doc(doc, base_model="audio-demo")
269
+        audio = tmp_path / "clip.wav"
270
+        audio.write_bytes(b"fake wav bytes")
271
+        runner = CliRunner()
272
+        captured: dict[str, Any] = {}
273
+
274
+        _patch_prompt_runtime(
275
+            monkeypatch,
276
+            spec=_spec(key="audio-demo", modality="audio-language"),
277
+        )
278
+        monkeypatch.setattr(
279
+            "dlm.cli.commands._dispatch_audio_prompt",
280
+            lambda **kwargs: captured.update(kwargs),
281
+        )
282
+
283
+        result = runner.invoke(
284
+            app,
285
+            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--audio", str(audio)],
286
+        )
287
+
288
+        assert result.exit_code == 0, result.output
289
+        assert captured["query"] == "hello"
290
+        assert captured["audio_paths"] == [audio]
291
+        assert captured["spec"].key == "audio-demo"
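The monkeypatched stand-ins above raise exceptions from a `lambda` via `(_ for _ in ()).throw(...)`. The trick works because a lambda body cannot contain a `raise` statement, while calling `.throw()` on a fresh generator re-raises the given exception at the call site. A standalone sketch of the idiom (not from the dlm codebase):

```python
# A lambda body cannot contain `raise`, so the tests build a fresh
# generator expression and call .throw() on it, which propagates the
# exception to the caller immediately.
raiser = lambda *args, **kwargs: (_ for _ in ()).throw(ValueError("boom"))

try:
    raiser("any", key="args")
except ValueError as exc:
    message = str(exc)
```

Because the generator is empty and never started, `.throw()` cannot be absorbed by the generator body, so the exception always escapes.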
tests/unit/cli/test_reporter.py (modified)
46 lines changed
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import pytest
 
-from dlm.cli.reporter import report_exception, run_with_reporter
+from dlm.cli.reporter import _prefix_for, report_exception, run_with_reporter
 
 
 class TestTier1ParseError:
@@ -44,6 +44,22 @@ class TestTier2DomainError:
         err = capsys.readouterr().err
         assert "export:" in err
 
+    def test_typed_error_verbose_env_surfaces_traceback(
+        self,
+        capsys: pytest.CaptureFixture[str],
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        from dlm.export.errors import UnsafeMergeError
+
+        monkeypatch.setenv("DLM_VERBOSE", "1")
+        exc = UnsafeMergeError("needs --dequantize")
+        code = report_exception(exc)
+
+        assert code == 1
+        err = capsys.readouterr().err
+        assert "export:" in err
+        assert "UnsafeMergeError" in err
+
 
 class TestPrefixMapping:
     """Each known module prefix gets a distinct colored label."""
@@ -91,6 +107,16 @@ class TestPrefixMapping:
         err = capsys.readouterr().err
         assert "base_model:" in err
 
+    def test_doc_prefix_branch_is_mapped(self) -> None:
+        doc_error = type("DocError", (Exception,), {"__module__": "dlm.doc.custom"})("boom")
+
+        assert _prefix_for(doc_error) == "doc"
+
+    def test_hardware_prefix_branch_is_mapped(self) -> None:
+        from dlm.hardware.refusals import ResolutionError
+
+        assert _prefix_for(ResolutionError("no plan")) == "doctor"
+
 
 class TestTier3Uncaught:
     def test_unknown_exception_gets_verbose_hint(self, capsys: pytest.CaptureFixture[str]) -> None:
tests/unit/cli/test_scaffold.py (modified)
22 lines changed
@@ -121,6 +121,22 @@ def test_second_run_reuses_existing(tmp_path: Path) -> None:
     assert second.dlm_id == first.dlm_id
 
 
+def test_second_run_reuses_lone_existing_file_when_default_name_is_unmatched(
+    tmp_path: Path,
+) -> None:
+    kwargs = _default_kwargs()
+    kwargs["name"] = "notes"
+    first = scaffold_train_target(tmp_path, **kwargs)  # type: ignore[arg-type]
+
+    resume = _default_kwargs()
+    resume["base"] = None
+    resolved = scaffold_train_target(tmp_path, **resume)  # type: ignore[arg-type]
+
+    assert resolved.scaffolded is False
+    assert resolved.dlm_path == first.dlm_path
+    assert resolved.dlm_id == first.dlm_id
+
+
 # ---- Multi-file disambiguation ---------------------------------------------
 
 
tests/unit/cli/test_train_prompt_repl_coverage.py (modified)
16 lines changed
@@ -284,6 +284,16 @@ class TestTrainCommandCoverage:
         )
         assert _maybe_dispatch_multi_gpu("bogus", ["dlm", "train"], console) == 2
 
+        class _BadResolveGpuSpec:
+            def resolve(self, device_count: int) -> tuple[int, ...]:
+                raise UnsupportedGpuSpecError("gpu index 7 is unavailable")
+
+        monkeypatch.setattr(
+            "dlm.train.distributed.parse_gpus",
+            lambda raw: _BadResolveGpuSpec(),
+        )
+        assert _maybe_dispatch_multi_gpu("7", ["dlm", "train"], console) == 2
+
         monkeypatch.setattr("dlm.train.distributed.parse_gpus", lambda raw: _GpuSpec((0,)))
         import torch
 
tests/unit/cli/test_train_validation_edges.py (added)
217 lines changed
@@ -0,0 +1,217 @@
+"""Extra edge coverage for the early validation block in `dlm train`."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from typer.testing import CliRunner
+
+import dlm.base_models as base_models
+from dlm.base_models.errors import GatedModelError
+from dlm.cli.app import app
+from dlm.cli.scaffold import ScaffoldError
+from dlm.doc.errors import DlmParseError
+
+
+def _write_minimal_dlm(path: Path) -> None:
+    path.write_text(
+        "---\n"
+        "dlm_id: 01TRAINEDGE0000000000000000\n"
+        "base_model: smollm2-135m\n"
+        "training:\n"
+        "  seed: 42\n"
+        "---\n"
+        "body\n",
+        encoding="utf-8",
+    )
+
+
+def _parsed_doc(base_model: str = "smollm2-135m") -> object:
+    return SimpleNamespace(
+        frontmatter=SimpleNamespace(
+            base_model=base_model,
+            dlm_id="01TRAINEDGE0000000000000000",
+            training=SimpleNamespace(sequence_len=2048),
+        )
+    )
+
+
+def _resolved_spec(**overrides: Any) -> object:
+    defaults: dict[str, Any] = {
+        "key": "smollm2-135m",
+        "revision": "0123456789abcdef0123456789abcdef01234567",
+        "modality": "text",
+        "params": 135_000_000,
+        "effective_context_length": 2048,
+    }
+    defaults.update(overrides)
+    return SimpleNamespace(**defaults)
+
+
+class TestTrainValidationEdges:
+    def test_invalid_phase_refused(self, tmp_path: Path) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+
+        runner = CliRunner()
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(doc), "--phase", "bogus"],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "--phase must be one of sft|preference|all" in result.output
+
+    def test_resume_and_fresh_refused_together(self, tmp_path: Path) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+
+        runner = CliRunner()
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(doc), "--resume", "--fresh"],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "--resume and --fresh are mutually exclusive" in result.output
+
+    def test_invalid_policy_refused(self, tmp_path: Path) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+
+        runner = CliRunner()
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(doc), "--policy", "bogus"],
+        )
+
+        assert result.exit_code == 2, result.output
+        assert "--policy must be 'permissive' or 'strict'" in result.output
+
+    def test_multi_gpu_exit_code_propagates(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        seen: dict[str, object] = {}
+
+        monkeypatch.setattr(
+            "dlm.hardware.capabilities.probe",
+            lambda: SimpleNamespace(supports_bf16=False),
+        )
+
+        def _fake_dispatch(
+            gpus: str,
+            argv: list[str],
+            console: object,
+            *,
+            mixed_precision: str = "bf16",
+        ) -> int | None:
+            seen["gpus"] = gpus
+            seen["argv"] = argv
+            seen["mixed_precision"] = mixed_precision
+            return 17
+
+        monkeypatch.setattr("dlm.cli.commands._maybe_dispatch_multi_gpu", _fake_dispatch)
+
+        runner = CliRunner()
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(doc), "--gpus", "0,1"],
+        )
+
+        assert result.exit_code == 17, result.output
+        assert seen["gpus"] == "0,1"
+        assert seen["mixed_precision"] == "fp16"
+
+    def test_scaffold_error_exits_cleanly(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        target = tmp_path / "corpus"
+        target.mkdir()
+
+        def _fake_scaffold(*args: object, **kwargs: object) -> object:
+            raise ScaffoldError("bad scaffold", path=target)
+
+        monkeypatch.setattr("dlm.cli.scaffold.scaffold_train_target", _fake_scaffold)
+
+        runner = CliRunner()
+        result = runner.invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(target), "--base", "smollm2-135m"],
+        )
+
+        assert result.exit_code == 1, result.output
+        assert "bad scaffold" in result.output
+
+    def test_parse_error_exits_cleanly(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        monkeypatch.setattr(
+            "dlm.doc.parser.parse_file",
+            lambda path: (_ for _ in ()).throw(
+                DlmParseError("broken frontmatter", path=doc, line=2, col=1)
+            ),
+        )
+
+        runner = CliRunner()
+        result = runner.invoke(app, ["--home", str(tmp_path), "train", str(doc)])
+
+        assert result.exit_code == 1, result.output
+        assert "broken frontmatter" in result.output
+
+    def test_gated_base_refusal_surfaces_license_pointer(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+
+        monkeypatch.setattr("dlm.doc.parser.parse_file", lambda path: _parsed_doc("llama-3.2-1b"))
+
+        def _fake_resolve(
+            base: str,
+            *,
+            accept_license: bool = False,
+            skip_export_probes: bool = False,
+        ) -> object:
+            raise GatedModelError(base, "https://example.test/license")
+
+        monkeypatch.setattr(base_models, "resolve", _fake_resolve)
+
+        runner = CliRunner()
+        result = runner.invoke(app, ["--home", str(tmp_path), "train", str(doc)])
+
+        assert result.exit_code == 1, result.output
+        text = " ".join(result.output.split())
+        assert "base model 'llama-3.2-1b' is gated" in text
+        assert "https://example.test/license" in text
+        assert "--i-accept-license" in text
+
+    def test_doctor_no_plan_refused(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+
+        monkeypatch.setattr("dlm.doc.parser.parse_file", lambda path: _parsed_doc())
+        monkeypatch.setattr(base_models, "resolve", lambda *args, **kwargs: _resolved_spec())
+        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)
+        monkeypatch.setattr(
+            "dlm.hardware.doctor",
+            lambda **kwargs: SimpleNamespace(plan=None, capabilities=object()),
+        )
+
+        runner = CliRunner()
+        result = runner.invoke(app, ["--home", str(tmp_path), "train", str(doc)])
+
+        assert result.exit_code == 1, result.output
+        assert "no viable training plan for this host" in result.output
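Helpers like `_parsed_doc` and `_resolved_spec` above fake out rich domain objects with nothing but `types.SimpleNamespace`, which gives plain attribute access without defining a class. A dependency-free sketch of the pattern (shapes mirror the test stubs, not the real dlm types):

```python
from types import SimpleNamespace

# Hypothetical stand-ins: a parsed doc exposes nested .frontmatter
# attributes, a resolved spec exposes flat fields. SimpleNamespace
# nests freely, so deep attribute chains work like the real objects.
doc = SimpleNamespace(
    frontmatter=SimpleNamespace(
        base_model="smollm2-135m",
        training=SimpleNamespace(sequence_len=2048),
    )
)
spec = SimpleNamespace(key="smollm2-135m", modality="text", params=135_000_000)

# Attribute access is all the code under test needs, so no class
# definitions or mocking library are required.
assert doc.frontmatter.training.sequence_len == 2048
```

The trade-off is that SimpleNamespace enforces no schema, so a stub can silently drift from the real type; these tests accept that in exchange for brevity.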
tests/unit/cli/test_train_watch_edges.py (added)
260 lines changed
@@ -0,0 +1,260 @@
+"""Additional `dlm train` coverage for lock-mode and watch-loop tails."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+from typer.testing import CliRunner
+
+import dlm.base_models as base_models
+from dlm.cli.app import app
+from dlm.modality.errors import ModalityError
+from dlm.watch.loop import CycleResult
+
+
+def _write_minimal_dlm(path: Path) -> None:
+    path.write_text(
+        "---\n"
+        "dlm_id: 01TRAINWATCH00000000000000\n"
+        "base_model: smollm2-135m\n"
+        "training:\n"
+        "  seed: 42\n"
+        "---\n"
+        "body\n",
+        encoding="utf-8",
+    )
+
+
+def _parsed_doc() -> object:
+    return SimpleNamespace(
+        frontmatter=SimpleNamespace(
+            base_model="smollm2-135m",
+            dlm_id="01TRAINWATCH00000000000000",
+            training=SimpleNamespace(sequence_len=2048),
+        ),
+        sections=[SimpleNamespace(content="body")],
+    )
+
+
+def _resolved_spec() -> object:
+    return SimpleNamespace(
+        key="smollm2-135m",
+        revision="0123456789abcdef0123456789abcdef01234567",
+        modality="text",
+        params=135_000_000,
+        effective_context_length=2048,
+    )
+
+
+def _fake_phase_result(tmp_path: Path) -> object:
+    result = SimpleNamespace(
+        adapter_version=1,
+        steps=3,
+        seed=42,
+        determinism=SimpleNamespace(class_="strict"),
+        adapter_path=tmp_path / "adapter",
+        log_path=tmp_path / "train.jsonl",
+        final_train_loss=0.25,
+        final_val_loss=0.1,
+    )
+    return SimpleNamespace(phase="sft", result=result)
+
+
+def _install_train_basics(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("dlm.doc.parser.parse_file", lambda path: _parsed_doc())
+    monkeypatch.setattr(base_models, "resolve", lambda *args, **kwargs: _resolved_spec())
+    monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)
+    monkeypatch.setattr(
+        "dlm.hardware.doctor",
+        lambda **kwargs: SimpleNamespace(plan=object(), capabilities=object()),
+    )
+
+
+class TestTrainLockModeEdges:
+    @pytest.mark.parametrize(
+        ("flag", "expected"),
+        [
+            ("--strict-lock", "strict"),
+            ("--update-lock", "update"),
+            ("--ignore-lock", "ignore"),
+        ],
+    )
+    def test_single_lock_flags_propagate_to_run_phases(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        flag: str,
+        expected: str,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        _install_train_basics(monkeypatch)
+        captured: dict[str, object] = {}
+
+        def _fake_run_phases(*args: object, **kwargs: object) -> list[object]:
+            captured["lock_mode"] = kwargs["lock_mode"]
+            return []
+
+        monkeypatch.setattr("dlm.train.preference.phase_orchestrator.run_phases", _fake_run_phases)
+
+        result = CliRunner().invoke(
+            app,
+            ["--home", str(tmp_path), "train", str(doc), flag],
+        )
+
+        assert result.exit_code == 0, result.output
+        assert captured["lock_mode"] == expected
+
+
+class TestTrainWatchEdges:
+    def test_modality_error_maps_to_training_prefix(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        _install_train_basics(monkeypatch)
+        monkeypatch.setattr(
+            "dlm.train.preference.phase_orchestrator.run_phases",
+            lambda *args, **kwargs: (_ for _ in ()).throw(
+                ModalityError("processor contract failed")
+            ),
+        )
+
+        result = CliRunner().invoke(app, ["--home", str(tmp_path), "train", str(doc)])
+
+        assert result.exit_code == 1, result.output
+        assert "training:" in result.output
+        assert "processor contract failed" in result.output
+
+    def test_watch_rpc_logs_cycle_and_skip_messages(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        _install_train_basics(monkeypatch)
+        monkeypatch.setenv("DLM_PROBE_TOKEN", "secret")
+        fake_phase = _fake_phase_result(tmp_path)
+
+        class _FakeQueue:
+            capacity = 5
+
+            def drain(self) -> list[object]:
+                return []
+
+        class _FakeServer:
+            def __init__(self, *, host: str, port: int, token: str, queue: object) -> None:
+                self.address = (host, port)
+                self.start_calls = 0
+                self.stop_calls = 0
+
+            def start(self) -> None:
+                self.start_calls += 1
+
+            def stop(self) -> None:
+                self.stop_calls += 1
+
+        def _fake_watch(**kwargs: object) -> int:
+            on_cycle = kwargs["on_cycle"]
+            on_cycle(
+                CycleResult(
+                    ran=True,
+                    new_sections=1,
+                    removed_sections=0,
+                    run_result=SimpleNamespace(final_train_loss=0.2, final_val_loss=0.1, steps=4),
+                )
+            )
+            on_cycle(CycleResult(ran=False, new_sections=0, removed_sections=0))
+            return 23
+
+        monkeypatch.setattr(
+            "dlm.train.preference.phase_orchestrator.run_phases",
+            lambda *args, **kwargs: [fake_phase],
+        )
+        monkeypatch.setattr("dlm.train.inject.InjectedProbeQueue", _FakeQueue)
+        monkeypatch.setattr("dlm.train.rpc.ProbeRpcServer", _FakeServer)
+        monkeypatch.setattr("dlm.watch.loop.run_watch", _fake_watch)
+
+        result = CliRunner().invoke(
+            app,
+            [
+                "--home",
+                str(tmp_path),
+                "train",
+                str(doc),
+                "--watch",
+                "--listen-rpc",
+                "127.0.0.1:7777",
+            ],
+        )
+
+        assert result.exit_code == 23, result.output
+        normalized = " ".join(result.output.split())
+        assert "rpc:" in normalized
+        assert "watch:" in normalized
+        assert "no new content, skipping retrain" in normalized
+
+    def test_watch_keyboard_interrupt_stops_server_and_exits_zero(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        doc = tmp_path / "doc.dlm"
+        _write_minimal_dlm(doc)
+        _install_train_basics(monkeypatch)
+        monkeypatch.setenv("DLM_PROBE_TOKEN", "secret")
+        fake_phase = _fake_phase_result(tmp_path)
+        holder: dict[str, object] = {}
+
+        class _FakeQueue:
+            capacity = 3
+
+            def drain(self) -> list[object]:
+                return []
+
+        class _FakeServer:
+            def __init__(self, *, host: str, port: int, token: str, queue: object) -> None:
+                self.address = (host, port)
+                self.stop_calls = 0
+                holder["server"] = self
+
+            def start(self) -> None:
+                return None
+
+            def stop(self) -> None:
+                self.stop_calls += 1
+
+        monkeypatch.setattr(
+            "dlm.train.preference.phase_orchestrator.run_phases",
+            lambda *args, **kwargs: [fake_phase],
+        )
+        monkeypatch.setattr("dlm.train.inject.InjectedProbeQueue", _FakeQueue)
+        monkeypatch.setattr("dlm.train.rpc.ProbeRpcServer", _FakeServer)
+        monkeypatch.setattr(
+            "dlm.watch.loop.run_watch",
+            lambda **kwargs: (_ for _ in ()).throw(KeyboardInterrupt),
+        )
+
+        result = CliRunner().invoke(
+            app,
+            [
+                "--home",
+                str(tmp_path),
+                "train",
+                str(doc),
+                "--watch",
+                "--listen-rpc",
+                "127.0.0.1:7777",
+            ],
+        )
+
+        assert result.exit_code == 0, result.output
+        assert "Ctrl-C received, exiting" in result.output
+        server = holder["server"]
+        assert isinstance(server, _FakeServer)
+        assert server.stop_calls == 2
tests/unit/control/test_apply.py (modified)
49 lines changed
@@ -8,6 +8,7 @@ import torch
 from torch import nn
 
 from dlm.control import ControlApplyError, apply_control
+from dlm.control.apply import _make_hook
 
 
 class _ToyLayer(nn.Module):
@@ -43,6 +44,12 @@ def _run_through_layer(model: _ToyModel, layer_index: int, hidden: torch.Tensor)
 
 
 class TestHookArithmetic:
+    def test_hook_with_no_args_returns_original_tuple(self) -> None:
+        hook = _make_hook(torch.ones(4), 1.0)
+        empty: tuple[object, ...] = ()
+
+        assert hook(None, empty) == empty
+
     def test_adds_scaled_vector_to_hidden(self) -> None:
         model = _ToyModel(n_layers=4, hidden_dim=8)
         vector = np.ones(8, dtype=np.float32)
@@ -177,3 +184,30 @@ class TestValidation:
         with apply_control(wrapped, vector, layer_index=2, strength=1.0):
             out = _run_through_layer(inner, 2, hidden)
         assert torch.allclose(out, torch.ones_like(out))
+
+    def test_falls_through_sparse_projection_paths(self) -> None:
+        class _SparseProjLayer(nn.Module):
+            def __init__(self, hidden_dim: int) -> None:
+                super().__init__()
+                self.self_attn = nn.Module()
+                self.self_attn.q_proj = nn.Module()
+                self.attn = nn.Module()
+                self.attn.qkv_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+            def forward(self, hidden: torch.Tensor) -> torch.Tensor:
+                return hidden
+
+        class _SparseProjModel(nn.Module):
+            def __init__(self, hidden_dim: int) -> None:
+                super().__init__()
+                self.model = nn.Module()
+                self.model.layers = nn.ModuleList([_SparseProjLayer(hidden_dim)])
+
+        model = _SparseProjModel(hidden_dim=4)
+        vector = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
+        hidden = torch.zeros(1, 1, 4)
+
+        with apply_control(model, vector, layer_index=0, strength=1.0):
+            out = model.model.layers[0](hidden)
+
+        assert out[0, 0, 0].item() == 1.0
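The `_make_hook` tests above exercise a hook of the usual forward-pre-hook shape: it receives `(module, args)` and returns an args tuple with the steering vector added to the hidden state, passing empty args through untouched. A dependency-free sketch of that contract (names and list-based arithmetic are illustrative, not the dlm implementation, which operates on tensors):

```python
def make_hook(vector, strength):
    # Forward-pre-hook shape: (module, args) -> args. With no
    # positional args there is no hidden state to steer, so the
    # original (empty) tuple is returned unchanged.
    def hook(module, args):
        if not args:
            return args
        hidden, *rest = args
        # Add the scaled steering vector element-wise to the hidden state.
        steered = [h + strength * v for h, v in zip(hidden, vector)]
        return (steered, *rest)

    return hook

hook = make_hook([1.0, 0.0], strength=2.0)
assert hook(None, ()) == ()
assert hook(None, ([0.0, 0.0],)) == ([2.0, 0.0],)
```

Returning the empty tuple unchanged is the edge branch `test_hook_with_no_args_returns_original_tuple` pins down: a hook must never invent arguments the wrapped module was not called with.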
tests/unit/data/test_audio_cache.py (modified)
36 lines changed
@@ -183,6 +183,36 @@ class TestProcessorSha256:
         proc_b = SimpleNamespace(feature_extractor=FeB())
         assert processor_sha256(proc_a) != processor_sha256(proc_b)
 
+    def test_nested_feature_extractor_fields_are_readable(self) -> None:
+        proc = SimpleNamespace(
+            feature_extractor=SimpleNamespace(
+                sampling_rate=16_000,
+                feature_size=(80, 2),
+                n_fft=400,
+                hop_length=160,
+                chunk_length={"seconds": 30},
+                padding_value=0.0,
+                return_attention_mask=True,
+            )
+        )
+        sha = processor_sha256(proc)
+        assert len(sha) == 64
+
+    def test_exotic_feature_field_stringifies_stably(self) -> None:
+        proc = SimpleNamespace(
+            feature_extractor=SimpleNamespace(
+                sampling_rate=16_000,
+                feature_size=80,
+                n_fft=400,
+                hop_length=160,
+                chunk_length=object(),
+                padding_value=0.0,
+                return_attention_mask=True,
+            )
+        )
+        sha = processor_sha256(proc)
+        assert len(sha) == 64
+
 
 # --- WaveformCache (35.2 deferred-item follow-up) ---------------------------
 
tests/unit/data/test_audio_resample.py (modified)
121 lines changed
@@ -10,8 +10,9 @@ Covers:
 
 from __future__ import annotations
 
+import builtins
 import sys
-from types import SimpleNamespace
+from types import ModuleType, SimpleNamespace
 
 import numpy as np
 import pytest
@@ -42,24 +43,17 @@ class TestBackendPickFailure:
     def test_no_backend_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
         """Force both imports to fail and confirm the error names both paths."""
 
-        real_import = (
-            __builtins__["__import__"]
-            if isinstance(__builtins__, dict)
-            else __builtins__.__import__
-        )
+        real_import = builtins.__import__
 
         def fake_import(name: str, *args: object, **kwargs: object) -> object:
             if name in ("soxr", "scipy", "scipy.signal"):
                 raise ImportError(f"forced: {name}")
-            return real_import(name, *args, **kwargs)  # type: ignore[operator]
+            return real_import(name, *args, **kwargs)
 
-        monkeypatch.setitem(sys.modules, "soxr", None)
-        # Monkey-patch the _pick_backend helper's import probes so both
-        # attempts fail regardless of what's installed in the env.
-        monkeypatch.setattr(audio_resample, "_pick_backend", _no_backend)
+        monkeypatch.setattr(builtins, "__import__", fake_import)
 
         with pytest.raises(AudioResampleUnavailable, match="soxr or scipy"):
-            resample(np.zeros(8, dtype=np.float32), src_sr=48_000, dst_sr=16_000)
+            audio_resample._pick_backend()
 
 
 def _no_backend() -> None:
@@ -71,21 +65,73 @@ def _no_backend() -> None:
 
 
 class TestScipyBackend:
-    def test_scipy_fallback_resamples(self, monkeypatch: pytest.MonkeyPatch) -> None:
-        """With soxr disabled, scipy fallback produces expected length."""
-        # Pretend soxr isn't importable so the pick falls through to scipy.
-        monkeypatch.setitem(sys.modules, "soxr", None)
+    def test_resample_routes_through_selected_backend(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        called: dict[str, object] = {}
 
-        pytest.importorskip("scipy.signal")
+        def fake_backend(waveform: np.ndarray, *, src_sr: int, dst_sr: int) -> np.ndarray:
74
+            called["waveform"] = waveform
75
+            called["src_sr"] = src_sr
76
+            called["dst_sr"] = dst_sr
77
+            return np.ones(4, dtype=np.float32)
8078
 
81
-        # 1 second of 8 kHz silence → resample to 16 kHz = 2 s of samples.
82
-        wave = np.zeros(8_000, dtype=np.float32)
79
+        wave = np.zeros(8, dtype=np.float32)
80
+        monkeypatch.setattr(audio_resample, "_pick_backend", lambda: fake_backend)
8381
         out = resample(wave, src_sr=8_000, dst_sr=16_000)
8482
 
83
+        assert out.tolist() == [1.0, 1.0, 1.0, 1.0]
84
+        assert called == {"waveform": wave, "src_sr": 8_000, "dst_sr": 16_000}
85
+
86
+    def test_scipy_fallback_uses_fake_module(self, monkeypatch: pytest.MonkeyPatch) -> None:
87
+        """With soxr disabled, _pick_backend falls through to scipy."""
88
+        monkeypatch.setitem(sys.modules, "soxr", None)
89
+        fake_signal = ModuleType("scipy.signal")
90
+        fake_signal.resample_poly = lambda waveform, *, up, down: np.repeat(waveform, up)[
91
+            : len(waveform) * up // down
92
+        ]
93
+        fake_scipy = ModuleType("scipy")
94
+        fake_scipy.signal = fake_signal
95
+        monkeypatch.setitem(sys.modules, "scipy", fake_scipy)
96
+        monkeypatch.setitem(sys.modules, "scipy.signal", fake_signal)
97
+
98
+        backend = audio_resample._pick_backend()
99
+        assert backend is audio_resample._scipy_resample
100
+
101
+    def test_soxr_backend_coerces_float32_contiguous(self, monkeypatch: pytest.MonkeyPatch) -> None:
102
+        fake_soxr = ModuleType("soxr")
103
+        fake_soxr.resample = lambda waveform, src_sr, dst_sr, quality="HQ": np.asarray(
104
+            waveform * 2, dtype=np.float64
105
+        )
106
+        monkeypatch.setitem(sys.modules, "soxr", fake_soxr)
107
+
108
+        wave = np.arange(6, dtype=np.float32)[::2]
109
+        out = audio_resample._soxr_resample(wave, src_sr=8_000, dst_sr=16_000)
110
+
111
+        assert out.dtype == np.float32
112
+        assert out.flags.c_contiguous
113
+        assert out.tolist() == [0.0, 4.0, 8.0]
114
+
115
+    def test_scipy_backend_reduces_ratio_before_call(self, monkeypatch: pytest.MonkeyPatch) -> None:
116
+        calls: dict[str, object] = {}
117
+
118
+        def fake_resample_poly(waveform: np.ndarray, *, up: int, down: int) -> np.ndarray:
119
+            calls["up"] = up
120
+            calls["down"] = down
121
+            return np.asarray(waveform + 1, dtype=np.float64)
122
+
123
+        fake_signal = ModuleType("scipy.signal")
124
+        fake_signal.resample_poly = fake_resample_poly
125
+        fake_scipy = ModuleType("scipy")
126
+        fake_scipy.signal = fake_signal
127
+        monkeypatch.setitem(sys.modules, "scipy", fake_scipy)
128
+        monkeypatch.setitem(sys.modules, "scipy.signal", fake_signal)
129
+
130
+        wave = np.arange(5, dtype=np.float32)
131
+        out = audio_resample._scipy_resample(wave, src_sr=48_000, dst_sr=16_000)
132
+
133
+        assert calls == {"up": 1, "down": 3}
85134
         assert out.dtype == np.float32
86
-        # resample_poly produces len(x) * up // down on integer ratios.
87
-        # scipy rounds up-or-down depending on filter length; accept ±1.
88
-        assert abs(out.shape[0] - 16_000) <= 1
89135
         assert out.flags.c_contiguous
90136
 
91137
 
tests/unit/data/test_dataset_builder.py (modified, 20 lines changed)
@@ -55,3 +55,20 @@ class TestBuildDataset:
         sections = [_s(SectionType.PROSE, "   ")]
         with pytest.raises(ValueError, match="no trainable rows"):
             build_dataset(sections, seed=0, val_frac=0.1)
+
+    def test_weights_dropping_every_row_raises(self) -> None:
+        replay = [
+            {
+                "text": "replay-only",
+                "_dlm_section_id": "replay-v1",
+                "_dlm_row_tags": {"lang": "en"},
+            },
+        ]
+        with pytest.raises(ValueError, match="weights dropped every row"):
+            build_dataset(
+                [],
+                seed=0,
+                val_frac=0.1,
+                replay_rows=replay,
+                weights={"lang": {"en": 0.0}},
+            )
tests/unit/data/test_sections_to_rows.py (modified, 20 lines changed)
@@ -50,6 +50,20 @@ class TestInstructionShape:
         with pytest.raises(InstructionParseError):
             sections_to_rows([s])
 
+    def test_probe_markers_normalized_before_parse(self) -> None:
+        s = _s(SectionType.INSTRUCTION, "### Q !probe\nq1\n### A\na1")
+        rows = sections_to_rows([s])
+        assert rows == [
+            {
+                "messages": [
+                    {"role": "user", "content": "q1"},
+                    {"role": "assistant", "content": "a1"},
+                ],
+                "_dlm_section_id": s.section_id,
+                "_dlm_row_tags": {},
+            },
+        ]
+
 
 class TestPreferenceShape:
     def test_each_triple_becomes_preference_row(self) -> None:
tests/unit/data/test_vl_cache.py (modified, 33 lines changed)
@@ -146,3 +146,33 @@ class TestProcessorSha256:
             image_std = [0.5] * 3
 
         assert processor_sha256(ProcA()) != processor_sha256(ProcB())
+
+    def test_nested_dict_and_tuple_fields_are_readable(self) -> None:
+        proc = SimpleNamespace(
+            image_processor=SimpleNamespace(
+                size={"shortest_edge": 224, "crop": (224, 224)},
+                image_mean=(0.5, 0.5, 0.5),
+                image_std=[0.2, 0.2, 0.2],
+                do_normalize=True,
+                do_rescale=True,
+                rescale_factor=1 / 255,
+                resample="bicubic",
+            )
+        )
+        sha = processor_sha256(proc)
+        assert len(sha) == 64
+
+    def test_exotic_resample_value_stringifies_stably(self) -> None:
+        proc = SimpleNamespace(
+            image_processor=SimpleNamespace(
+                size={"shortest_edge": 224},
+                image_mean=[0.5] * 3,
+                image_std=[0.5] * 3,
+                do_normalize=True,
+                do_rescale=True,
+                rescale_factor=1 / 255,
+                resample=object(),
+            )
+        )
+        sha = processor_sha256(proc)
+        assert len(sha) == 64
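Both cache test files (audio and VL) drive a processor fingerprint that must digest nested dicts and tuples and still produce a stable 64-char hex digest when a field holds an exotic value like `object()`. A hedged sketch of that hashing idea — the field handling and helper name here are illustrative, not the project's actual `processor_sha256`:

```python
import hashlib


def fingerprint_fields(fields: dict[str, object]) -> str:
    """Hash a processor's config fields into a stable 64-char hex digest."""

    def stringify(value: object) -> str:
        # Primitive leaves (and nested dicts/lists/tuples of them) repr
        # deterministically; anything exotic collapses to its type name so
        # the digest never depends on a memory address.
        if isinstance(value, (dict, list, tuple, str, int, float, bool)) or value is None:
            return repr(value)
        return f"<{type(value).__name__}>"

    payload = "|".join(f"{k}={stringify(v)}" for k, v in sorted(fields.items()))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

The key property the tests assert: the digest is always well-formed (64 hex chars) and two processors differing only by an `object()` sentinel hash identically run to run.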
tests/unit/data/test_weighted_rows.py (modified, 41 lines changed)
@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 from dlm.data.weighted_rows import (
+    _keep_fraction,
     expand_rows_by_weight,
+    merge_weights_maps,
     resolve_row_weight,
     weight_distribution,
 )
@@ -105,6 +107,32 @@ class TestExpandRowsByWeight:
         assert len(out) == 6
 
 
+class TestMergeWeightsMaps:
+    def test_empty_sequence_returns_empty_map(self) -> None:
+        assert merge_weights_maps([]) == {}
+
+    def test_deeper_entries_override_shallower_ones(self) -> None:
+        merged = merge_weights_maps(
+            [
+                {"lang": {"py": 2.0, "rs": 1.5}, "gen": {"true": 0.5}},
+                {"lang": {"py": 3.0}, "new": {"x": 4.0}},
+            ]
+        )
+        assert merged == {
+            "lang": {"py": 3.0, "rs": 1.5},
+            "gen": {"true": 0.5},
+            "new": {"x": 4.0},
+        }
+
+
+class TestKeepFraction:
+    def test_non_positive_fraction_never_keeps(self) -> None:
+        assert _keep_fraction("sid", seed=42, fractional=0.0) is False
+
+    def test_fraction_at_or_above_one_always_keeps(self) -> None:
+        assert _keep_fraction("sid", seed=42, fractional=1.0) is True
+
+
 class TestWeightDistribution:
     def test_empty_rows_empty_dist(self) -> None:
         assert weight_distribution([]) == {}
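The new `TestMergeWeightsMaps` cases pin down the merge semantics: later maps win at the leaf level while sibling keys from earlier maps survive. A minimal sketch consistent with those assertions, assuming the two-level `str -> str -> float` shape the tests use (the project's real `merge_weights_maps` may differ):

```python
from collections.abc import Iterable


def merge_weights_maps(
    maps: Iterable[dict[str, dict[str, float]]],
) -> dict[str, dict[str, float]]:
    """Merge weight maps left to right; later leaf values override earlier ones."""
    merged: dict[str, dict[str, float]] = {}
    for weights in maps:
        for tag, values in weights.items():
            # Update per-tag so untouched leaves under the same tag survive.
            merged.setdefault(tag, {}).update(values)
    return merged
```

The `setdefault(...).update(...)` step is what distinguishes this from a shallow `dict.update`, which would have clobbered `{"rs": 1.5}` when the second map redefined `"lang"`.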
tests/unit/directives/test_cache.py (modified, 206 lines changed)
@@ -31,6 +31,7 @@ class TestOpen:
 
     def test_empty_cache(self, tmp_path: Path) -> None:
         cache = TokenizedCache.open(tmp_path / "c")
+        assert cache.root == tmp_path / "c"
         assert cache.entry_count == 0
         assert cache.total_bytes == 0
 
@@ -47,6 +48,94 @@ class TestOpen:
         assert cache.entry_count == 0
         assert any("unreadable" in rec.message for rec in caplog.records)
 
+    def test_manifest_version_mismatch_starts_fresh(
+        self,
+        tmp_path: Path,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        import json
+        import logging
+
+        root = tmp_path / "c"
+        root.mkdir()
+        (root / "manifest.json").write_text(
+            json.dumps({"version": 999, "entries": {}}),
+            encoding="utf-8",
+        )
+        caplog.set_level(logging.WARNING, logger="dlm.directives.cache")
+        cache = TokenizedCache.open(root)
+        assert cache.entry_count == 0
+        assert "version mismatch" in caplog.text
+
+    def test_non_mapping_entries_starts_fresh(self, tmp_path: Path) -> None:
+        import json
+
+        root = tmp_path / "c"
+        root.mkdir()
+        (root / "manifest.json").write_text(
+            json.dumps({"version": 1, "entries": []}),
+            encoding="utf-8",
+        )
+        cache = TokenizedCache.open(root)
+        assert cache.entry_count == 0
+
+    def test_malformed_manifest_entry_is_skipped(
+        self,
+        tmp_path: Path,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        import json
+        import logging
+
+        root = tmp_path / "c"
+        root.mkdir()
+        (root / "manifest.json").write_text(
+            json.dumps(
+                {
+                    "version": 1,
+                    "entries": {
+                        "good": {
+                            "size": 4,
+                            "last_access_ts": 1.0,
+                            "shard": "aa",
+                            "filename": "good.npz",
+                            "tokenizer_sha": "a" * 64,
+                        },
+                        "bad": {
+                            "size": "not-an-int",
+                            "last_access_ts": 1.0,
+                            "shard": "bb",
+                            "filename": "bad.npz",
+                        },
+                    },
+                }
+            ),
+            encoding="utf-8",
+        )
+        caplog.set_level(logging.WARNING, logger="dlm.directives.cache")
+        cache = TokenizedCache.open(root)
+        assert cache.entry_count == 1
+        assert "skipping malformed entry" in caplog.text
+
+    def test_non_mapping_manifest_entry_is_ignored(self, tmp_path: Path) -> None:
+        import json
+
+        root = tmp_path / "c"
+        root.mkdir()
+        (root / "manifest.json").write_text(
+            json.dumps(
+                {
+                    "version": 1,
+                    "entries": {
+                        "bad": "not-a-dict",
+                    },
+                }
+            ),
+            encoding="utf-8",
+        )
+        cache = TokenizedCache.open(root)
+        assert cache.entry_count == 0
+
 
 class TestGetPut:
     def test_miss_then_hit(self, tmp_path: Path) -> None:
@@ -81,6 +170,68 @@ class TestGetPut:
         assert cache2.entry_count == 1
         assert cache2.get(key) is not None
 
+    def test_hit_rate_zero_when_no_lookups(self, tmp_path: Path) -> None:
+        cache = TokenizedCache.open(tmp_path / "c")
+        assert cache.hit_rate == 0.0
+
+    def test_corrupt_entry_recovers(
+        self,
+        tmp_path: Path,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        import logging
+
+        cache = TokenizedCache.open(tmp_path / "c")
+        key = _key()
+        cache.put(key, _tokens(4))
+        entry_file = next((tmp_path / "c" / "entries").rglob("*.npz"))
+        entry_file.write_bytes(b"not a real npz")
+
+        caplog.set_level(logging.WARNING, logger="dlm.directives.cache")
+        assert cache.get(key) is None
+        assert cache.entry_count == 0
+        assert "corrupt entry" in caplog.text
+
+    def test_put_write_failure_drops_entry(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        import logging
+
+        cache = TokenizedCache.open(tmp_path / "c")
+
+        def _boom(*_args: object, **_kwargs: object) -> None:
+            raise OSError("disk full")
+
+        monkeypatch.setattr("numpy.savez_compressed", _boom)
+        caplog.set_level(logging.WARNING, logger="dlm.directives.cache")
+        cache.put(_key(), _tokens(4))
+
+        assert cache.entry_count == 0
+        assert list((tmp_path / "c" / "entries").rglob("*.tmp")) == []
+        assert "write failed" in caplog.text
+
+    def test_put_stat_failure_records_zero_sized_entry(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        cache = TokenizedCache.open(tmp_path / "c")
+        real_stat = Path.stat
+
+        def _patched_stat(path: Path, *, follow_symlinks: bool = True) -> object:
+            if path.suffix == ".tmp":
+                raise OSError("no stat")
+            return real_stat(path, follow_symlinks=follow_symlinks)
+
+        monkeypatch.setattr(Path, "stat", _patched_stat)
+        cache.put(_key(), _tokens(4))
+
+        assert cache.entry_count == 1
+        assert cache.total_bytes == 0
+
 
 class TestInvalidation:
     def test_different_tokenizer_sha_misses(self, tmp_path: Path) -> None:
@@ -150,6 +301,37 @@ class TestLRUEviction:
         assert cache.get(key_a) is not None
         assert cache.get(key_b) is not None
 
+    def test_eviction_stops_once_under_budget(self, tmp_path: Path) -> None:
+        import time
+
+        cache = TokenizedCache.open(tmp_path / "c", max_bytes=10_000_000)
+        key_a = _key("aa" * 8)
+        key_b = _key("bb" * 8)
+        key_c = _key("cc" * 8)
+
+        cache.put(key_a, _tokens(20))
+        size_a = next(
+            entry.size
+            for entry in cache._manifest.values()
+            if entry.filename == key_a.as_filename()
+        )
+        cache._touched_this_run.clear()
+        time.sleep(0.01)
+        cache.put(key_b, _tokens(20))
+        size_b = next(
+            entry.size
+            for entry in cache._manifest.values()
+            if entry.filename == key_b.as_filename()
+        )
+        cache._touched_this_run.clear()
+
+        cache._max_bytes = size_a + size_b
+        cache.put(key_c, _tokens(20))
+
+        assert cache.get(key_a) is None
+        assert cache.get(key_b) is not None
+        assert cache.get(key_c) is not None
+
 
 class TestPruneClear:
     def test_prune_removes_old_entries(self, tmp_path: Path) -> None:
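`test_eviction_stops_once_under_budget` checks the LRU loop's stopping condition: oldest-accessed entries are dropped one at a time, and eviction halts the moment total size fits the budget (so `key_b` survives even though `key_a` goes). A sketch of that loop under an assumed entry shape of `{"size": int, "last_access_ts": float}` — the cache's real manifest entries are richer:

```python
def evict_until_under_budget(entries: dict[str, dict], max_bytes: int) -> list[str]:
    """Drop least-recently-accessed entries until total size <= max_bytes."""
    evicted: list[str] = []
    total = sum(e["size"] for e in entries.values())
    # Oldest access timestamp first == least recently used.
    for key in sorted(entries, key=lambda k: entries[k]["last_access_ts"]):
        if total <= max_bytes:
            break  # stop as soon as we are back under budget
        total -= entries[key]["size"]
        del entries[key]
        evicted.append(key)
    return evicted
```

Setting the budget to exactly `size_a + size_b` before inserting `key_c`, as the test does, is what forces precisely one eviction and proves the loop does not over-evict.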
tests/unit/directives/test_cache_key.py (modified, 28 lines changed)
@@ -54,11 +54,18 @@ class _FakeBackendTokenizer:
         return self._canonical
 
 
+class _BrokenBackendTokenizer:
+    def to_str(self) -> str:
+        raise RuntimeError("boom")
+
+
 class _FakeTokenizer:
     """Minimal shape for tokenizer_sha256 — just enough attrs."""
 
     def __init__(self, *, canonical: str | None = None, vocab_size: int = 32000) -> None:
-        self.backend_tokenizer = _FakeBackendTokenizer(canonical) if canonical else None
+        self.backend_tokenizer: object | None = (
+            _FakeBackendTokenizer(canonical) if canonical else None
+        )
         self.vocab_size = vocab_size
         self.model_max_length = 2048
         self.pad_token = "<pad>"
@@ -101,3 +108,9 @@ class TestTokenizerSha256:
         tok.backend_tokenizer = _FakeBackendTokenizer('{"v": 2}')
         sha2 = tokenizer_sha256(tok)
         assert sha1 == sha2
+
+    def test_backend_to_str_failure_falls_back_to_legacy(self) -> None:
+        tok = _FakeTokenizer()
+        tok.backend_tokenizer = _BrokenBackendTokenizer()
+        sha = tokenizer_sha256(tok)
+        assert len(sha) == 64
tests/unit/directives/test_discovery.py (modified, 58 lines changed)
@@ -27,6 +27,11 @@ def test_no_dlm_dirs_yields_empty(tmp_path: Path) -> None:
     assert discover_configs(tmp_path) == ()
 
 
+def test_non_directory_dot_dlm_is_ignored(tmp_path: Path) -> None:
+    (tmp_path / ".dlm").write_text("not a directory", encoding="utf-8")
+    assert discover_configs(tmp_path) == ()
+
+
 def test_single_dlm_at_root_with_both_files(tmp_path: Path) -> None:
     (tmp_path / ".dlm").mkdir()
     (tmp_path / ".dlm" / "training.yaml").write_text(
@@ -75,6 +80,18 @@ def test_malformed_yaml_logs_and_continues(
     assert any("invalid YAML" in rec.message for rec in caplog.records)
 
 
+def test_invalid_utf8_training_yaml_logs_and_continues(
+    tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    (tmp_path / ".dlm").mkdir()
+    (tmp_path / ".dlm" / "training.yaml").write_bytes(b"caf\xe9\n")
+    caplog.set_level(logging.WARNING, logger="dlm.directives.discovery")
+    configs = discover_configs(tmp_path)
+    assert configs[0].config is None
+    assert any("not UTF-8" in rec.message for rec in caplog.records)
+
+
 def test_schema_violation_logs_and_continues(
     tmp_path: Path, caplog: pytest.LogCaptureFixture
 ) -> None:
@@ -97,6 +114,14 @@ def test_training_yaml_non_mapping_top_level(
     assert any("must be a mapping" in rec.message for rec in caplog.records)
 
 
+def test_training_yaml_null_top_level_coerces_to_empty_config(tmp_path: Path) -> None:
+    (tmp_path / ".dlm").mkdir()
+    (tmp_path / ".dlm" / "training.yaml").write_text("null\n", encoding="utf-8")
+    configs = discover_configs(tmp_path)
+    assert configs[0].config is not None
+    assert configs[0].config.dlm_training_version == 1
+
+
 def test_both_files_coexist(tmp_path: Path) -> None:
     (tmp_path / ".dlm").mkdir()
     (tmp_path / ".dlm" / "training.yaml").write_text("dlm_training_version: 1\nexclude: ['a']\n")
@@ -105,3 +130,15 @@ def test_both_files_coexist(tmp_path: Path) -> None:
     assert c.config is not None
     assert c.config.exclude == ("a",)
     assert len(c.ignore_rules) == 1
+
+
+def test_invalid_utf8_ignore_logs_and_continues(
+    tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    (tmp_path / ".dlm").mkdir()
+    (tmp_path / ".dlm" / "ignore").write_bytes(b"bad-\xff\n")
+    caplog.set_level(logging.WARNING, logger="dlm.directives.discovery")
+    configs = discover_configs(tmp_path)
+    assert configs[0].ignore_rules == ()
+    assert any("not UTF-8" in rec.message for rec in caplog.records)
tests/unit/directives/test_expand.py (modified, 81 lines changed)
@@ -7,14 +7,17 @@ directives fast-path.
 
 from __future__ import annotations
 
+import os
 from pathlib import Path
 
 import pytest
 
 from dlm.directives import expand_sources
 from dlm.directives.errors import DirectivePathError, DirectivePolicyError
+from dlm.directives.expand import _iter_candidates
 from dlm.doc.parser import parse_text
 from dlm.doc.sections import SectionType
+from dlm.store.blobs import BlobStore
 
 _VALID_ULID = "01ABCDEFGHJKMNPQRSTVWXYZ00"
 
@@ -180,3 +183,64 @@ def test_single_file_directive(tmp_path: Path) -> None:
     result = expand_sources(parsed, base_path=tmp_path)  # type: ignore[arg-type]
     assert len(result.sections) == 1
     assert result.sections[0].content.startswith("# source: notes.md")
+
+
+def test_stat_failure_skips_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    src = tmp_path / "src"
+    src.mkdir()
+    target = src / "a.py"
+    target.write_text("print(1)\n", encoding="utf-8")
+    body = "  sources:\n    - path: src\n      include: ['**/*.py']\n"
+    parsed, _ = _make_parsed(body, tmp_path)
+    real_stat = Path.stat
+    seen_target = 0
+
+    def _patched_stat(path: Path, *, follow_symlinks: bool = True) -> os.stat_result:
+        nonlocal seen_target
+        if path == target:
+            seen_target += 1
+            if seen_target >= 2:
+                raise OSError("no stat")
+        return real_stat(path, follow_symlinks=follow_symlinks)
+
+    monkeypatch.setattr(Path, "stat", _patched_stat)
+    result = expand_sources(parsed, base_path=tmp_path)  # type: ignore[arg-type]
+    assert result.sections == ()
+
+
+def test_read_bytes_failure_skips_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    src = tmp_path / "src"
+    src.mkdir()
+    target = src / "a.py"
+    target.write_text("print(1)\n", encoding="utf-8")
+    body = "  sources:\n    - path: src\n      include: ['**/*.py']\n"
+    parsed, _ = _make_parsed(body, tmp_path)
+    real_read_bytes = Path.read_bytes
+
+    def _patched_read_bytes(path: Path, *args: object, **kwargs: object) -> bytes:
+        if path == target:
+            raise OSError("no read")
+        return real_read_bytes(path, *args, **kwargs)
+
+    monkeypatch.setattr(Path, "read_bytes", _patched_read_bytes)
+    result = expand_sources(parsed, base_path=tmp_path)  # type: ignore[arg-type]
+    assert result.sections == ()
+
+
+def test_audio_transcript_unreadable_skips_audio(tmp_path: Path) -> None:
+    corpus = tmp_path / "corpus"
+    corpus.mkdir()
+    (corpus / "clip.wav").write_bytes(b"RIFF....fake wav")
+    (corpus / "clip.txt").write_bytes(b"bad-\xff\n")
+    parsed, _ = _make_parsed(
+        '  sources:\n    - path: corpus\n      include: ["**/*.wav"]\n',
+        tmp_path,
+    )
+    blob_store = BlobStore(tmp_path / "blobs")
+    result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)  # type: ignore[arg-type]
+    assert result.sections == ()
+    assert result.provenance[0].skipped_audio_no_transcript == 1
+
+
+def test_iter_candidates_non_file_non_dir_yields_nothing(tmp_path: Path) -> None:
+    assert list(_iter_candidates(tmp_path / "missing")) == []
tests/unit/directives/test_ignore_parser.py (modified, 13 lines changed)
@@ -68,6 +68,13 @@ def test_parse_bare_slash_skipped(caplog: pytest.LogCaptureFixture) -> None:
     assert any("bare '/'" in rec.message for rec in caplog.records)
 
 
+def test_parse_pattern_reduced_to_empty_skipped(caplog: pytest.LogCaptureFixture) -> None:
+    caplog.set_level(logging.WARNING, logger="dlm.directives.ignore_parser")
+    rules = parse_ignore_file("//\n")
+    assert rules == ()
+    assert any("pattern reduced to empty" in rec.message for rec in caplog.records)
+
+
 # ---- matches ---------------------------------------------------------------
 
 
tests/unit/directives/test_merge.py (modified, 32 lines changed)
@@ -119,6 +119,19 @@ def test_training_yaml_exclude_blocks_file(tmp_path: Path) -> None:
     )
 
 
+def test_parent_directive_exclude_blocks_file(tmp_path: Path) -> None:
+    _write(tmp_path / "src" / "main.py", "x")
+    directive = _directive(tmp_path, exclude=("**/*.py",))
+    configs = discover_configs(tmp_path)
+    eff = effective_config_for(
+        tmp_path / "src" / "main.py",
+        source_root=tmp_path,
+        discovered=configs,
+        parent_directive=directive,
+    )
+    assert eff.included is False
+
+
 # ---- .dlm/ignore negation --------------------------------------------------
 
 
@@ -266,3 +279,13 @@ def test_metadata_empty_when_no_training_yaml(tmp_path: Path) -> None:
         parent_directive=directive,
     )
     assert dict(eff.tags) == {}
+
+
+def test_relpath_falls_back_to_filename_for_non_ancestor_anchor(tmp_path: Path) -> None:
+    from dlm.directives.merge import _relpath
+
+    file_path = tmp_path / "src" / "main.py"
+    _write(file_path, "x")
+    other_anchor = tmp_path / "elsewhere"
+    other_anchor.mkdir()
+    assert _relpath(file_path, other_anchor) == "main.py"
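The final merge test asserts that `_relpath` degrades to the bare filename when the anchor directory is not an ancestor of the file. The standard-library way to get exactly that behavior is to catch the `ValueError` that `Path.relative_to` raises for non-subpaths — a sketch with an illustrative name, since the project's `_relpath` is private:

```python
from pathlib import Path


def relpath_or_name(file_path: Path, anchor: Path) -> str:
    """Path relative to anchor; fall back to the filename if anchor is not an ancestor."""
    try:
        return file_path.relative_to(anchor).as_posix()
    except ValueError:  # anchor is not an ancestor of file_path
        return file_path.name
```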
tests/unit/doc/test_parser_roundtrip.py (modified, 35 lines changed)
@@ -197,6 +197,35 @@ class TestFenceGrammar:
             SectionType.PREFERENCE,
         ]
 
+    def test_whitespace_only_trailing_prose_is_elided(self) -> None:
+        text = (
+            f"---\ndlm_id: {VALID_ULID}\nbase_model: smollm2-135m\n---\n\n"
+            "::instruction::\n"
+            "### Q\n"
+            "q\n"
+            "### A\n"
+            "a\n"
+            "\n"
+            "   \n"
+            "\t\n"
+        )
+        parsed = parse_text(text)
+        assert [section.type for section in parsed.sections] == [SectionType.INSTRUCTION]
+
+    def test_whitespace_only_prose_before_first_fence_is_elided(self) -> None:
+        text = (
+            f"---\ndlm_id: {VALID_ULID}\nbase_model: smollm2-135m\n---\n\n"
+            "   \n"
+            "\t\n"
+            "::instruction::\n"
+            "### Q\n"
+            "q\n"
+            "### A\n"
+            "a\n"
+        )
+        parsed = parse_text(text)
+        assert [section.type for section in parsed.sections] == [SectionType.INSTRUCTION]
+
     def test_unknown_attribute_fence_raises(self) -> None:
         text = (
             f"---\ndlm_id: {VALID_ULID}\nbase_model: smollm2-135m\n---\n\n"
tests/unit/doc/test_serializer_edges.py (modified, 18 lines changed)
@@ -173,6 +173,18 @@ class TestSerializeTrailingNewline:
         assert out.endswith("\n")
         assert not out.endswith("\n\n")
 
+    def test_serializer_adds_newline_when_section_render_omits_it(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        fm = DlmFrontmatter(dlm_id=VALID_ULID, base_model="smollm2-135m")
+        parsed = ParsedDlm(
+            frontmatter=fm,
+            sections=(Section(SectionType.PROSE, "content"),),
+        )
+        monkeypatch.setattr("dlm.doc.serializer._serialize_frontmatter", lambda _fm: "---\n---")
+        monkeypatch.setattr("dlm.doc.serializer._serialize_section", lambda _section: "body")
+        assert serialize(parsed).endswith("\n")
+
 
 class TestFrontmatterExplicitTargetModulesList:
     """Ensures the list branch in the nested-mapping emitter is exercised."""
tests/unit/eval/test_mode_split.py (modified, 59 lines changed)
@@ -8,7 +8,7 @@ from unittest.mock import MagicMock
 
 import pytest
 
-from dlm.eval.mode_split import compute_val_loss_by_mode
+from dlm.eval.mode_split import _safe_eval_loss, compute_val_loss_by_mode
 
 
 class _FakeDataset:
@@ -68,6 +68,14 @@ class TestEmptyOrMissing:
         assert compute_val_loss_by_mode(trainer, _FakeDataset([])) == (None, None)
         trainer.evaluate.assert_not_called()
 
+    def test_non_sized_dataset_returns_both_none(self) -> None:
+        trainer = MagicMock()
+        assert compute_val_loss_by_mode(trainer, _NonSizedDataset([{"text": "prose"}])) == (
+            None,
+            None,
+        )
+        trainer.evaluate.assert_not_called()
+
 
 class TestModeClassification:
     def test_only_cpt_rows(self) -> None:
@@ -149,6 +157,14 @@ class TestEvalFailures:
         assert cpt is None
         assert sft is None
 
+    def test_non_numeric_eval_loss_yields_none(self) -> None:
+        trainer = MagicMock()
+        trainer.evaluate.return_value = {"eval_loss": object()}
+        val = _FakeDataset([{"text": "a"}])
+        cpt, sft = compute_val_loss_by_mode(trainer, val)
+        assert cpt is None
+        assert sft is None
+
     def test_select_failure_yields_none(
         self,
         caplog: pytest.LogCaptureFixture,
@@ -174,3 +190,23 @@ class _NoSelectDataset:
 
     def __iter__(self):  # type: ignore[no-untyped-def]
         return iter(self._rows)
+
+
+class _NonSizedDataset:
+    def __init__(self, rows: list[dict[str, Any]]) -> None:
+        self._rows = rows
+
+    def __iter__(self):  # type: ignore[no-untyped-def]
+        return iter(self._rows)
+
+
+def test_safe_eval_loss_value_error_yields_none(
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
+    trainer = MagicMock()
+    trainer.evaluate.side_effect = ValueError("bad eval")
+    val = _FakeDataset([{"text": "a"}])
+
+    assert _safe_eval_loss(trainer, val, [0], mode="cpt") is None
+    assert "val-loss split skipped cpt evaluation" in caplog.text
tests/unit/eval/test_probes.py (modified, 56 lines changed)
@@ -7,8 +7,9 @@ import logging
 
 import pytest
 
+from dlm.data.instruction_parser import QAPair
 from dlm.doc.sections import Section, SectionType
-from dlm.eval.probes import Probe, extract_probes
+from dlm.eval.probes import Probe, _auto_sample_probes, _normalize_probe_markers, extract_probes
 
 
 class TestExplicitProbes:
@@ -42,6 +43,15 @@ class TestExplicitProbes:
         # Auto-sampled fills the remainder.
         assert any(p.prompt == "not-probe" for p in probes)
 
+    def test_probe_header_preserves_blank_lines_before_prompt(self) -> None:
+        body = "### Q !probe\n\n\nWhat is Paris?\n### A\nCapital of France."
+        s = Section(type=SectionType.INSTRUCTION, content=body)
+
+        probes = extract_probes([s], k=1)
+
+        assert len(probes) == 1
+        assert probes[0].prompt == "What is Paris?"
+
 
 class TestAutoSample:
     def test_auto_sample_when_no_explicit(self) -> None:
@@ -74,6 +84,22 @@ class TestAutoSample:
         s = Section(type=SectionType.INSTRUCTION, content=body)
         assert extract_probes([s], k=0) == []
 
+    def test_auto_sample_internal_k_zero_returns_empty(self) -> None:
+        assert _auto_sample_probes([], k=0, seed=0, exclude=set(), parsed_pairs={}) == []
+
+    def test_auto_sample_exclude_skips_seen_prompts(self) -> None:
+        section = Section(type=SectionType.INSTRUCTION, content="### Q\nQ1?\n### A\nA1")
+
+        probes = _auto_sample_probes(
+            [section],
+            k=1,
+            seed=0,
+            exclude={"Q1?"},
+            parsed_pairs={section.section_id: [QAPair(question="Q1?", answer="A1")]},
+        )
+
+        assert probes == []
+
     def test_malformed_instruction_logs_warning_once(
         self,
         caplog: pytest.LogCaptureFixture,
@@ -91,3 +117,9 @@ class TestProbeDataclass:
         p = Probe(prompt="hi", reference="hello")
         with pytest.raises(dataclasses.FrozenInstanceError):
             p.prompt = "other"  # type: ignore[misc]
+
+
+def test_normalize_probe_markers_keeps_non_probe_body() -> None:
+    body = "### Q\nplain\n### A\nanswer"
+
+    assert _normalize_probe_markers(body) == body
tests/unit/export/ollama/test_modelfile.py (modified, 11 lines changed)
@@ -135,6 +135,11 @@ class TestShape:
         text = render_modelfile(_ctx(tmp_path))
         assert 'LICENSE "Apache-2.0"' in text
 
+    def test_license_line_omitted_when_spec_has_no_spdx(self, tmp_path: Path) -> None:
+        spec = _SPEC.model_copy(update={"license_spdx": ""})
+        text = render_modelfile(_ctx(tmp_path, spec=spec))
+        assert "LICENSE " not in text
+
     def test_trailing_newline(self, tmp_path: Path) -> None:
         assert render_modelfile(_ctx(tmp_path)).endswith("\n")
 
tests/unit/export/targets/test_llama_server_argv.py (modified, 140 lines changed)
@@ -6,9 +6,27 @@ import json
 from datetime import datetime
 from pathlib import Path
 
+import pytest
+
 from dlm.base_models import BASE_MODELS
+from dlm.export.dispatch import DispatchResult
+from dlm.export.errors import ExportError
 from dlm.export.manifest import ExportManifest, load_export_manifest
-from dlm.export.targets.llama_server import prepare_llama_server_export
+from dlm.export.targets.base import TargetResult
+from dlm.export.targets.llama_server import (
+    LLAMA_SERVER_TARGET,
+    _find_artifact,
+    _optional_int_extra,
+    _optional_path_extra,
+    _optional_prepared_path,
+    _read_chat_template,
+    _require_path_extra,
+    _require_prepared_int,
+    _require_prepared_path,
+    _require_spec_extra,
+    _script_dir_arg,
+    prepare_llama_server_export,
+)
 
 
 def _vendor_tree(tmp_path: Path) -> Path:
@@ -117,3 +135,112 @@ class TestPrepareLlamaServerExport:
         script = prepared.launch_script_path.read_text(encoding="utf-8")
         assert "--lora " not in script
         assert "--ctx-size 512" in script
+
+
+class TestLlamaServerHelpers:
+    def test_read_chat_template_rejects_invalid_json(self, tmp_path: Path) -> None:
+        adapter_dir = tmp_path / "adapter"
+        adapter_dir.mkdir()
+        (adapter_dir / "tokenizer_config.json").write_text("not json {{{", encoding="utf-8")
+
+        with pytest.raises(ExportError, match="cannot load chat template"):
+            _read_chat_template(adapter_dir)
+
+    def test_read_chat_template_rejects_blank_template(self, tmp_path: Path) -> None:
+        adapter_dir = tmp_path / "adapter"
+        adapter_dir.mkdir()
+        (adapter_dir / "tokenizer_config.json").write_text(
+            json.dumps({"chat_template": "   "}),
+            encoding="utf-8",
+        )
+
+        with pytest.raises(ExportError, match="has no non-empty chat_template"):
+            _read_chat_template(adapter_dir)
+
+    def test_find_artifact_missing_prefix_raises(self, tmp_path: Path) -> None:
+        with pytest.raises(ExportError, match="missing export artifact with prefix"):
+            _find_artifact([tmp_path / "adapter.gguf"], prefix="base.")
+
+    def test_script_dir_arg_requires_path(self) -> None:
+        with pytest.raises(ExportError, match="missing a required path"):
+            _script_dir_arg(None)
+
+    def test_dispatch_extra_validators_raise_on_wrong_types(self, tmp_path: Path) -> None:
+        ctx = DispatchResult(
+            export_dir=tmp_path,
+            manifest_path=tmp_path / "export_manifest.json",
+            artifacts=[],
+            banner_lines=[],
+            extras={
+                "adapter_dir": "bad",
+                "training_sequence_len": "bad",
+                "spec": "bad",
+                "vendor_override": "bad",
+            },
+        )
+
+        with pytest.raises(ExportError, match="missing Path extra 'adapter_dir'"):
+            _require_path_extra(ctx, "adapter_dir")
+        with pytest.raises(ExportError, match="must be an int"):
+            _optional_int_extra(ctx, "training_sequence_len")
+        with pytest.raises(ExportError, match="missing BaseModelSpec extra 'spec'"):
+            _require_spec_extra(ctx, "spec")
+        with pytest.raises(ExportError, match="must be a Path"):
+            _optional_path_extra(ctx, "vendor_override")
+
+        empty_ctx = DispatchResult(
+            export_dir=tmp_path,
+            manifest_path=tmp_path / "export_manifest.json",
+            artifacts=[],
+            banner_lines=[],
+            extras={},
+        )
+        assert _optional_path_extra(empty_ctx, "vendor_override") is None
+        assert _optional_int_extra(empty_ctx, "training_sequence_len") is None
+
+    def test_prepared_extra_validators_raise_on_wrong_types(self, tmp_path: Path) -> None:
+        prepared = TargetResult(
+            name="llama-server",
+            export_dir=tmp_path,
+            manifest_path=tmp_path / "export_manifest.json",
+            config_path=tmp_path / "chat-template.jinja",
+            extras={
+                "model_path": tmp_path / "base.gguf",
+                "adapter_gguf_path": "bad",
+                "context_length": 512,
+            },
+        )
+
+        assert _require_prepared_path(prepared, "model_path") == tmp_path / "base.gguf"
+        with pytest.raises(ExportError, match="must be a Path"):
+            _optional_prepared_path(prepared, "adapter_gguf_path")
+        with pytest.raises(ExportError, match="must be a Path"):
+            LLAMA_SERVER_TARGET.launch_command(prepared)
+
+        bad_int = TargetResult(
+            name="llama-server",
+            export_dir=tmp_path,
+            manifest_path=tmp_path / "export_manifest.json",
+            config_path=tmp_path / "chat-template.jinja",
+            extras={
+                "model_path": tmp_path / "base.gguf",
+                "context_length": "bad",
+            },
+        )
+        with pytest.raises(ExportError, match="missing int extra 'context_length'"):
+            _require_prepared_int(bad_int, "context_length")
+
+    def test_smoke_failure_from_runtime_command_is_reported(self, tmp_path: Path) -> None:
+        prepared = TargetResult(
+            name="llama-server",
+            export_dir=tmp_path,
+            manifest_path=tmp_path / "export_manifest.json",
+            extras={"model_path": "bad"},
+            config_path=tmp_path / "chat-template.jinja",
+        )
+
+        result = LLAMA_SERVER_TARGET.smoke_test(prepared)
+
+        assert result.attempted is True
+        assert result.ok is False
+        assert "missing Path extra 'model_path'" in result.detail
tests/unit/export/targets/test_mlx_serve_argv.py (modified, 116 lines changed)
@@ -15,6 +15,9 @@ from dlm.export.targets.mlx_serve import (
     MLX_SERVE_TARGET,
     _quote_script_arg,
     _require_prepared_int,
+    _require_prepared_path,
+    _require_prepared_str,
+    _version_from_dir_name,
     finalize_mlx_serve_export,
     prepare_mlx_serve_export,
 )
@@ -115,6 +118,30 @@ class TestPrepareMlxServeExport:
         assert store_manifest.exports[-1].quant == "hf"
         assert store_manifest.exports[-1].smoke_output_first_line == "hello from mlx"
 
+    def test_prepare_replaces_stale_staged_adapter_dir(
+        self, tmp_path: Path, monkeypatch: object
+    ) -> None:
+        store = _setup_flat_store(tmp_path)
+        export_dir = store.exports / "mlx-serve"
+        stale_dir = export_dir / "adapter"
+        stale_dir.mkdir(parents=True)
+        (stale_dir / "stale.txt").write_text("stale", encoding="utf-8")
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.is_apple_silicon", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.mlx_available", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.stage_mlx_adapter_dir", _fake_stage_mlx)
+
+        prepared = prepare_mlx_serve_export(
+            store=store,
+            spec=_SPEC,
+            adapter_name=None,
+            adapter_path_override=None,
+            declared_adapter_names=None,
+        )
+
+        assert prepared.launch_script_path is not None
+        assert not (prepared.export_dir / "adapter" / "stale.txt").exists()
+        assert (prepared.export_dir / "adapter" / "adapters.safetensors").exists()
+
     def test_multi_adapter_export_requires_explicit_selection(
         self, tmp_path: Path, monkeypatch: object
     ) -> None:
@@ -174,6 +201,58 @@ class TestPrepareMlxServeExport:
                 declared_adapter_names=None,
             )
 
+    def test_named_adapter_export_uses_named_dir(self, tmp_path: Path, monkeypatch: object) -> None:
+        store = _setup_named_store(tmp_path)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.is_apple_silicon", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.mlx_available", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.stage_mlx_adapter_dir", _fake_stage_mlx)
+
+        prepared = prepare_mlx_serve_export(
+            store=store,
+            spec=_SPEC,
+            adapter_name="knowledge",
+            adapter_path_override=None,
+            declared_adapter_names=None,
+        )
+
+        assert str(prepared.extras["adapter_dir"]).endswith("knowledge")
+        assert prepared.extras["adapter_version"] == 2
+
+    def test_missing_adapter_override_raises(self, tmp_path: Path, monkeypatch: object) -> None:
+        store = _setup_flat_store(tmp_path)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.is_apple_silicon", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.mlx_available", lambda: True)
+
+        with pytest.raises(ExportError, match="adapter_path_override .* does not exist"):
+            prepare_mlx_serve_export(
+                store=store,
+                spec=_SPEC,
+                adapter_name=None,
+                adapter_path_override=tmp_path / "missing",
+                declared_adapter_names=None,
+            )
+
+    def test_existing_adapter_override_uses_mixed_dir(
+        self, tmp_path: Path, monkeypatch: object
+    ) -> None:
+        store = _setup_flat_store(tmp_path)
+        override = tmp_path / "custom-adapter"
+        _write_adapter(override)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.is_apple_silicon", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.mlx_available", lambda: True)
+        monkeypatch.setattr("dlm.export.targets.mlx_serve.stage_mlx_adapter_dir", _fake_stage_mlx)
+
+        prepared = prepare_mlx_serve_export(
+            store=store,
+            spec=_SPEC,
+            adapter_name=None,
+            adapter_path_override=override,
+            declared_adapter_names=None,
+        )
+
+        assert str(prepared.extras["adapter_dir"]).endswith("mixed")
+        assert prepared.extras["adapter_version"] == 1
+
     def test_missing_default_adapter_raises(self, tmp_path: Path, monkeypatch: object) -> None:
        store = for_dlm("01EMPTYMLX", home=tmp_path)
        store.ensure_layout()
@@ -262,3 +341,19 @@ class TestMlxServeHelpers:
         )
         with pytest.raises(ExportError, match="missing int extra"):
             _require_prepared_int(prepared, "adapter_version")
+
+    def test_string_and_path_validation(self) -> None:
+        prepared = TargetResult(
+            name="mlx-serve",
+            export_dir=Path("/tmp/export"),
+            manifest_path=Path("/tmp/export/export_manifest.json"),
+            extras={"model": "", "adapter_dir": "bad"},
+        )
+
+        with pytest.raises(ExportError, match="missing string extra"):
+            _require_prepared_str(prepared, "model")
+        with pytest.raises(ExportError, match="missing Path extra"):
+            _require_prepared_path(prepared, "adapter_dir")
+
+    def test_version_from_dir_name_defaults_for_non_version_dirs(self) -> None:
+        assert _version_from_dir_name(Path("custom-adapter")) == 1
tests/unit/export/targets/test_vllm_argv.py (modified, 77 lines changed)
@@ -16,12 +16,14 @@ from dlm.export.targets.vllm import (
     VLLM_TARGET,
     LoraModule,
     _default_runtime_env,
+    _machine,
     _optional_prepared_int,
     _render_launch_script,
     _require_module_specs,
     _require_prepared_int,
     _require_prepared_str,
     _runtime_env,
+    _sys_platform,
     finalize_vllm_export,
     prepare_vllm_export,
 )
@@ -128,6 +130,27 @@ class TestPrepareVllmExport:
         assert store_manifest.exports[-1].quant == "hf"
         assert store_manifest.exports[-1].smoke_output_first_line == "hello from vllm"
 
+    def test_prepare_replaces_stale_adapters_dir(self, tmp_path: Path) -> None:
+        store = _setup_flat_store(tmp_path)
+        export_dir = store.exports / "vllm"
+        stale_dir = export_dir / "adapters"
+        stale_dir.mkdir(parents=True)
+        (stale_dir / "stale.txt").write_text("stale", encoding="utf-8")
+
+        prepared = prepare_vllm_export(
+            store=store,
+            spec=_SPEC,
+            served_model_name="dlm-flat",
+            training_sequence_len=2048,
+            adapter_name=None,
+            adapter_path_override=None,
+            declared_adapter_names=None,
+        )
+
+        assert prepared.launch_script_path is not None
+        assert not (prepared.export_dir / "adapters" / "stale.txt").exists()
+        assert (prepared.export_dir / "adapters" / "adapter" / "adapter_model.safetensors").exists()
+
     def test_multi_adapter_export_includes_all_named_modules(self, tmp_path: Path) -> None:
         store = _setup_named_store(tmp_path)
 
@@ -244,6 +267,26 @@ class TestPrepareVllmExport:
                 declared_adapter_names=None,
             )
 
+    def test_named_adapter_export_stages_only_named_module(self, tmp_path: Path) -> None:
+        store = _setup_named_store(tmp_path)
+
+        prepared = prepare_vllm_export(
+            store=store,
+            spec=_SPEC,
+            served_model_name="dlm-knowledge",
+            training_sequence_len=2048,
+            adapter_name="knowledge",
+            adapter_path_override=None,
+            declared_adapter_names=None,
+        )
+
+        config = json.loads(
+            (prepared.export_dir / VLLM_CONFIG_FILENAME).read_text(encoding="utf-8")
+        )
+        assert config["lora_modules"] == [
+            {"adapter_version": 2, "name": "knowledge", "path": "adapters/knowledge"}
+        ]
+
     def test_missing_default_adapter_raises(self, tmp_path: Path) -> None:
         store = for_dlm("01EMPTYVLLM", home=tmp_path)
         store.ensure_layout()
@@ -355,6 +398,10 @@ class TestVllmSmoke:
 
 
 class TestVllmHelpers:
+    def test_platform_helpers_return_strings(self) -> None:
+        assert isinstance(_sys_platform(), str)
+        assert isinstance(_machine(), str)
+
     def test_default_runtime_env_is_empty_off_apple_silicon(
         self, monkeypatch: pytest.MonkeyPatch
     ) -> None:
tests/unit/export/test_arch_probe.py (modified, 29 lines changed)
@@ -19,6 +19,7 @@ from pathlib import Path
 
 import pytest
 
+import dlm.export.arch_probe as arch_probe
 from dlm.export.arch_probe import (
     ArchProbeResult,
     SupportLevel,
@@ -173,6 +174,22 @@ class TestGrammarEdgeCases:
         # "GemmaForCausalLM" (without the 3) isn't registered.
         assert result.support is SupportLevel.UNSUPPORTED
 
+    def test_decorator_without_following_class_is_ignored(self, tmp_path: Path) -> None:
+        root = _fixture_llama_cpp(
+            tmp_path,
+            '@ModelBase.register("FooForCausalLM")\n# no class follows\n',
+        )
+        result = probe_gguf_arch("FooForCausalLM", llama_cpp_root=root)
+        assert result.support is SupportLevel.UNSUPPORTED
+
+    def test_unextractable_class_name_is_ignored(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        text = '@ModelBase.register("FooForCausalLM")\nclass FooModel(TextModel):\n    pass\n'
+        monkeypatch.setattr(arch_probe, "_extract_class_name", lambda _text, _start: None)
+        assert arch_probe._find_arch_bindings(text, "FooForCausalLM") == []
+
+    def test_extract_class_name_returns_none_without_open_paren(self) -> None:
+        assert arch_probe._extract_class_name("class FooModel:\n    pass\n", 0) is None
+
 
class TestMemoization:
    def test_repeat_calls_hit_cache(self, tmp_path: Path) -> None:
tests/unit/export/test_audio_snapshot.py (modified, 105 lines changed)
@@ -12,6 +12,7 @@ Mirrors `test_vl_snapshot.py`. Covers:
 
 from __future__ import annotations
 
+import json
 from pathlib import Path
 
 import pytest
@@ -97,6 +98,12 @@ class TestRefusals:
         with pytest.raises(ExportError, match="only audio-language bases"):
             run_audio_snapshot_export(populated_store, _text_spec())
 
+    def test_missing_audio_preprocessor_plan_refused(self, populated_store) -> None:
+        spec = _audio_spec()
+        object.__setattr__(spec, "audio_preprocessor_plan", None)
+        with pytest.raises(ExportError, match="no audio_preprocessor_plan"):
+            run_audio_snapshot_export(populated_store, spec)
+
     def test_missing_adapter_refused(self, tmp_path: Path) -> None:
         store = for_dlm(_VALID_ULID, home=tmp_path)
         store.ensure_layout()
@@ -142,6 +149,72 @@ class TestSnapshotLayout:
         assert result.export_dir.name != "hf-snapshot"
         assert result.export_dir.name == AUDIO_SNAPSHOT_SUBDIR
 
+    def test_named_adapter_export_uses_named_current_pointer(self, populated_store) -> None:
+        named = populated_store.adapter_version_for("podcast", 5)
+        named.mkdir(parents=True, exist_ok=True)
+        (named / "adapter_config.json").write_text('{"r": 32}', encoding="utf-8")
+        (named / "adapter_model.safetensors").write_bytes(b"named audio bytes")
+        populated_store.set_current_adapter_for("podcast", named)
+
+        result = run_audio_snapshot_export(
+            populated_store,
+            _audio_spec(),
+            adapter_name="podcast",
+        )
+
+        assert (
+            result.adapter_dir / "adapter_model.safetensors"
+        ).read_bytes() == b"named audio bytes"
+        manifest = load_audio_snapshot_manifest(result.export_dir)
+        assert manifest.adapter_version == 5
+        assert manifest.adapter_name == "podcast"
+
+    def test_adapter_override_uses_provided_dir(self, populated_store, tmp_path: Path) -> None:
+        override = tmp_path / "merged-adapter"
+        override.mkdir()
+        (override / "adapter_model.safetensors").write_bytes(b"override audio bytes")
+
+        result = run_audio_snapshot_export(
+            populated_store,
+            _audio_spec(),
+            adapter_path_override=override,
+        )
+
+        assert (
+            result.adapter_dir / "adapter_model.safetensors"
+        ).read_bytes() == b"override audio bytes"
+        manifest = load_audio_snapshot_manifest(result.export_dir)
+        assert manifest.adapter_version == 1
+
+    def test_missing_adapter_override_refused(self, populated_store, tmp_path: Path) -> None:
+        with pytest.raises(ExportError, match="adapter_path_override .* does not exist"):
+            run_audio_snapshot_export(
+                populated_store,
+                _audio_spec(),
+                adapter_path_override=tmp_path / "missing",
+            )
+
+    def test_processor_save_pretrained_writes_processor_artifact(self, populated_store) -> None:
+        class _Processor:
+            def save_pretrained(self, out_dir: str) -> None:
+                Path(out_dir, "processor_config.json").write_text("{}", encoding="utf-8")
+
+        result = run_audio_snapshot_export(populated_store, _audio_spec(), processor=_Processor())
+
+        assert (result.processor_dir / "processor_config.json").exists()
+        manifest = load_audio_snapshot_manifest(result.export_dir)
+        paths = {entry.path for entry in manifest.artifacts}
+        assert "processor/processor_config.json" in paths
+
+    def test_noncallable_processor_save_is_ignored(self, populated_store) -> None:
+        class _Processor:
+            save_pretrained = "not-callable"
+
+        result = run_audio_snapshot_export(populated_store, _audio_spec(), processor=_Processor())
+
+        assert result.processor_dir.exists()
+        assert not any(result.processor_dir.iterdir())
+
 
class TestManifestContent:
    def test_export_target_is_hf_snapshot(self, populated_store) -> None:
@@ -213,6 +286,14 @@ class TestManifestLoadFailures:
         with pytest.raises(ExportManifestError, match="cannot parse"):
             load_audio_snapshot_manifest(tmp_path)
 
+    def test_invalid_shape_raises(self, tmp_path: Path) -> None:
+        (tmp_path / SNAPSHOT_MANIFEST_FILENAME).write_text(
+            json.dumps({"created_by": "dlm-test"}),
+            encoding="utf-8",
+        )
+        with pytest.raises(ExportManifestError, match="invalid shape"):
+            load_audio_snapshot_manifest(tmp_path)
+
 
class TestManifestModelDirect:
    def test_frozen(self) -> None:
tests/unit/export/test_draft_registry.py (modified, 20 lines changed)
@@ -79,6 +79,20 @@ class TestValidatorRejectsMismatches:
         with pytest.raises(ValueError, match="target_key 'qwen2.5-3b' not in BASE_MODELS"):
             validate_registry(registry)
 
+    def test_missing_draft_registry_key_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        import dlm.export.draft_registry as mod
+
+        bad_pair = DraftPair(
+            target_key="a",
+            draft_registry_key="missing",
+            upstream_ollama_tag="a:tiny",
+            notes="missing draft key",
+        )
+        monkeypatch.setattr(mod, "DRAFT_PAIRS", (bad_pair,))
+        registry = {"a": self._fake_spec()}
+        with pytest.raises(ValueError, match="draft_registry_key 'missing' not in BASE_MODELS"):
+            validate_registry(registry)
+
     def test_mismatched_template_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
         import dlm.export.draft_registry as mod
 
tests/unit/export/test_embedding_sync.py (modified, 51 lines changed)
@@ -3,11 +3,14 @@
 from __future__ import annotations
 
 import json
+import sys
 from pathlib import Path
+from types import SimpleNamespace
 
 import numpy as np
 import pytest
 
+import dlm.export.embedding_sync as embedding_sync
 from dlm.export.embedding_sync import assert_embedding_rows_match
 from dlm.export.errors import PreflightError
 from dlm.export.gguf_tensors import GGML_TYPE_F16
@@ -435,6 +438,37 @@ class TestRobustSkips:
 
         assert _added_special_token_ids(adapter) == [3]
 
+    def test_non_dict_added_tokens_decoder_returns_empty(self, tmp_path: Path) -> None:
+        adapter = tmp_path / "adapter"
+        adapter.mkdir()
+        (adapter / "tokenizer_config.json").write_text(
+            json.dumps(
+                {
+                    "vocab_size": 5,
+                    "chat_template": "x",
+                    "added_tokens_decoder": ["not", "a", "dict"],
+                }
+            )
+        )
+        assert embedding_sync._added_special_token_ids(adapter) == []
+
+    def test_safe_open_oserror_raises_preflight(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        adapter = tmp_path / "adapter"
+        adapter.mkdir()
+        (adapter / "adapter_model.safetensors").write_bytes(b"placeholder")
+
+        def _boom(*_args: object, **_kwargs: object) -> object:
+            raise OSError("broken safetensors")
+
+        monkeypatch.setitem(sys.modules, "safetensors", SimpleNamespace(safe_open=_boom))
+        with pytest.raises(PreflightError, match="cannot read adapter safetensors"):
+            embedding_sync._load_adapter_safetensors(adapter)
+
+    def test_row_list_returns_empty_for_rank1_tensor(self) -> None:
+        assert embedding_sync._as_row_list(np.asarray([1.0, 2.0], dtype=np.float16)) == []
+
 
class TestBoundsChecks:
    def test_added_token_id_out_of_range_raises(self, tmp_path: Path) -> None:
tests/unit/export/test_gate_fallback_resolve.py (modified, 72 lines changed)
@@ -3,13 +3,14 @@
 from __future__ import annotations
 
 import json
+from dataclasses import replace
 from pathlib import Path
 from types import MappingProxyType
 
 from dlm.doc.parser import ParsedDlm
 from dlm.doc.schema import AdapterConfig, DlmFrontmatter, GateConfig, TrainingConfig
 from dlm.doc.sections import Section, SectionType
-from dlm.export.gate_fallback import resolve_gate_mix
+from dlm.export.gate_fallback import resolve_and_announce, resolve_gate_mix
 from dlm.metrics.events import GateEvent, RunStart
 from dlm.metrics.recorder import MetricsRecorder
 from dlm.store.paths import StorePath
@@ -73,6 +74,23 @@ def test_no_gate_config_returns_none(tmp_path: Path) -> None:
     assert resolve_gate_mix(store, _parsed()) is None
 
 
+def test_single_adapter_returns_none(tmp_path: Path) -> None:
+    store = StorePath(root=tmp_path)
+    store.ensure_layout()
+    parsed = _parsed(gate_enabled=False, adapters=("solo",))
+    single_adapter_training = parsed.frontmatter.training.model_copy(
+        update={"gate": GateConfig(enabled=True)}
+    )
+    single_adapter_frontmatter = parsed.frontmatter.model_copy(
+        update={"training": single_adapter_training}
+    )
+    assert resolve_gate_mix(store, replace(parsed, frontmatter=single_adapter_frontmatter)) is None
+
+
+def test_non_store_or_non_parsed_returns_none() -> None:
+    assert resolve_gate_mix(object(), object()) is None
+
+
def test_uniform_mode_returns_uniform_mix(tmp_path: Path) -> None:
    store = StorePath(root=tmp_path)
    store.ensure_layout()
@@ -152,3 +170,34 @@ def test_preserves_declared_adapter_order(tmp_path: Path) -> None:
     mix = resolve_gate_mix(store, _parsed(adapters=("zeta", "alpha")))
     # Order must match the config's adapter_names tuple, not alphabetic.
     assert mix == [("zeta", 0.4), ("alpha", 0.6)]
+
+
+def test_resolve_and_announce_no_substitution(tmp_path: Path) -> None:
+    store = StorePath(root=tmp_path)
+    store.ensure_layout()
+
+    resolution = resolve_and_announce(store, _parsed(gate_enabled=False))
+
+    assert resolution.entries is None
+    assert resolution.banner_lines == []
+
+
+def test_resolve_and_announce_substitution_banner(tmp_path: Path) -> None:
+    store = StorePath(root=tmp_path)
+    store.ensure_layout()
+    _write_gate_config(
+        store,
+        GateMetadata(
+            input_dim=576,
+            hidden_proj_dim=64,
+            adapter_names=("a", "b"),
+            mode="uniform",
+        ),
+    )
+
+    resolution = resolve_and_announce(store, _parsed())
+
+    assert resolution.entries == [("a", 0.5), ("b", 0.5)]
+    assert resolution.banner_lines == [
+        "[dim]export: substituting learned gate weights for --adapter-mix (gate_mode=static).[/dim]"
+    ]
tests/unit/export/test_gguf_io.py (added)
33 lines changed
@@ -0,0 +1,33 @@
+"""Private GGUF IO helper coverage."""
+
+from __future__ import annotations
+
+import io
+import struct
+
+import pytest
+
+from dlm.export._gguf_io import _TYPE_ARRAY, _TYPE_STRING, _read_string, _read_u64, _skip_value
+
+
+def test_read_u64_short_read_raises() -> None:
+    with pytest.raises(struct.error, match="short read"):
+        _read_u64(io.BytesIO(b"\x01\x02"))
+
+
+def test_read_string_short_read_raises() -> None:
+    data = io.BytesIO(struct.pack("<Q", 4) + b"ab")
+
+    with pytest.raises(struct.error, match="short read in string"):
+        _read_string(data)
+
+
+def test_skip_value_string_array_huge_length_raises() -> None:
+    data = io.BytesIO(
+        struct.pack("<I", _TYPE_STRING)
+        + struct.pack("<Q", 1)
+        + struct.pack("<Q", (16 * 1024 * 1024) + 1)
+    )
+
+    with pytest.raises(struct.error, match="exceeds bound"):
+        _skip_value(data, _TYPE_ARRAY)
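Outside the diff, the short-read discipline these tests pin down can be sketched as a standalone helper pair. This is a hypothetical reimplementation for illustration only; the real `_read_u64`/`_read_string` live in `dlm.export._gguf_io` and their exact signatures are not shown here.

```python
import io
import struct


def read_u64(stream: io.BufferedIOBase) -> int:
    # Refuse short reads instead of silently unpacking garbage.
    raw = stream.read(8)
    if len(raw) != 8:
        raise struct.error("short read")
    return struct.unpack("<Q", raw)[0]


def read_string(stream: io.BufferedIOBase) -> str:
    # Length-prefixed UTF-8 string; the declared length must be fully present.
    length = read_u64(stream)
    raw = stream.read(length)
    if len(raw) != length:
        raise struct.error("short read in string")
    return raw.decode("utf-8")
```

Raising `struct.error` (rather than returning a partial value) is what lets the tests above assert on a single exception type for every truncated-input path.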
tests/unit/export/test_gguf_tensors.py (modified)
75 lines changed
@@ -239,6 +239,33 @@ class TestLoadTensorIndex:
         with pytest.raises(PreflightError, match="cannot parse GGUF"):
             load_tensor_index(path)

+    def test_short_tensor_name_read_refused(self, tmp_path: Path) -> None:
+        header = bytearray(b"GGUF")
+        header.extend(struct.pack("<I", 3))
+        header.extend(struct.pack("<Q", 1))  # tensor_count
+        header.extend(struct.pack("<Q", 1))  # kv_count
+        _write_kv_u32(header, "general.alignment", 32)
+        header.extend(struct.pack("<Q", 5))  # claims 5 bytes
+        header.extend(b"abc")  # only 3 bytes available
+        path = tmp_path / "short-name.gguf"
+        path.write_bytes(bytes(header))
+        with pytest.raises(PreflightError, match="cannot parse GGUF"):
+            load_tensor_index(path)
+
+    @pytest.mark.parametrize("n_dims", [0, 9])
+    def test_invalid_tensor_rank_refused(self, tmp_path: Path, n_dims: int) -> None:
+        header = bytearray(b"GGUF")
+        header.extend(struct.pack("<I", 3))
+        header.extend(struct.pack("<Q", 1))  # tensor_count
+        header.extend(struct.pack("<Q", 1))  # kv_count
+        _write_kv_u32(header, "general.alignment", 32)
+        _write_string(header, "token_embd.weight")
+        header.extend(struct.pack("<I", n_dims))
+        path = tmp_path / f"ndims-{n_dims}.gguf"
+        path.write_bytes(bytes(header))
+        with pytest.raises(PreflightError, match="cannot parse GGUF"):
+            load_tensor_index(path)
+

 class TestRowBytesErrors:
     def _build_basic(self, tmp_path: Path) -> Path:
@@ -277,6 +304,42 @@ class TestRowBytesErrors:
         with pytest.raises(PreflightError, match="block-quantized"):
             index.row_bytes("token_embd.weight", 0)

+    def test_rank_zero_tensor_refused(self, tmp_path: Path) -> None:
+        index = load_tensor_index(self._build_basic(tmp_path))
+        index = index.__class__(
+            path=index.path,
+            entries=(
+                TensorEntry(name="token_embd.weight", shape=(), dtype=GGML_TYPE_F16, offset=0),
+            ),
+            data_block_start=index.data_block_start,
+            alignment=index.alignment,
+        )
+        with pytest.raises(PreflightError, match="rank 0"):
+            index.row_bytes("token_embd.weight", 0)
+
+    def test_short_row_read_raises(self, tmp_path: Path) -> None:
+        path = self._build_basic(tmp_path)
+        index = load_tensor_index(path)
+        path.write_bytes(path.read_bytes()[:-1])
+        with pytest.raises(PreflightError, match="short read on row 1"):
+            index.row_bytes("token_embd.weight", 1)
+
+    def test_oserror_while_opening_tensor_raises(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        path = self._build_basic(tmp_path)
+        index = load_tensor_index(path)
+        original_open = Path.open
+
+        def _boom(self: Path, *args: object, **kwargs: object) -> object:
+            if self == index.path:
+                raise OSError("nope")
+            return original_open(self, *args, **kwargs)
+
+        monkeypatch.setattr(Path, "open", _boom)
+        with pytest.raises(PreflightError, match="cannot read row 0"):
+            index.row_bytes("token_embd.weight", 0)
+

 class TestFindApi:
     def test_find_returns_entry_or_none(self, tmp_path: Path) -> None:
tests/unit/export/test_imatrix.py (modified)
108 lines changed
@@ -334,6 +334,36 @@ class TestResolveImatrix:
             is None
         )

+    def test_non_string_sha_returns_none(self, tmp_path: Path) -> None:
+        export_dir = self._seed(tmp_path)
+        meta = json.loads((export_dir / "imatrix.meta.json").read_text())
+        meta["sha256"] = 123
+        (export_dir / "imatrix.meta.json").write_text(json.dumps(meta))
+        assert (
+            resolve_imatrix(
+                export_dir,
+                base_revision="r1",
+                corpus_sha256="c1",
+                chunks=DEFAULT_CHUNKS,
+            )
+            is None
+        )
+
+    def test_invalid_built_at_returns_none(self, tmp_path: Path) -> None:
+        export_dir = self._seed(tmp_path)
+        meta = json.loads((export_dir / "imatrix.meta.json").read_text())
+        meta["built_at"] = "not-a-datetime"
+        (export_dir / "imatrix.meta.json").write_text(json.dumps(meta))
+        assert (
+            resolve_imatrix(
+                export_dir,
+                base_revision="r1",
+                corpus_sha256="c1",
+                chunks=DEFAULT_CHUNKS,
+            )
+            is None
+        )
+

 # --- calibration_text_from_replay --------------------------------------------

@@ -417,3 +447,72 @@ class TestCalibrationTextFromReplay:
         # `max_chars` is the pre-joiner content budget; the `\n\n`
         # separator between snapshots adds a small constant overhead.
         assert len(text) <= 8_000 + 2 * 10  # 10 possible joiners
+
+    def test_empty_and_whitespace_snapshots_are_skipped(self, tmp_path: Path) -> None:
+        from datetime import UTC
+        from datetime import datetime as _dt
+
+        from dlm.replay.models import SectionSnapshot
+        from dlm.replay.store import ReplayStore
+
+        corpus = tmp_path / "corpus.zst"
+        idx = tmp_path / "index.json"
+        store = ReplayStore.at(corpus, idx)
+        snaps = [
+            SectionSnapshot(
+                section_id="0000000000000001",
+                section_type="prose",
+                content="",
+                first_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+                last_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+            ),
+            SectionSnapshot(
+                section_id="0000000000000002",
+                section_type="prose",
+                content="   \n\t  ",
+                first_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+                last_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+            ),
+            SectionSnapshot(
+                section_id="0000000000000003",
+                section_type="prose",
+                content="real calibration content",
+                first_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+                last_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+            ),
+        ]
+        store.append_many(snaps)
+
+        text, _sha = calibration_text_from_replay(corpus_path=corpus, index_path=idx)
+        assert text == "real calibration content"
+
+    def test_truncation_can_clip_with_zero_remaining_budget(self, tmp_path: Path) -> None:
+        from datetime import UTC
+        from datetime import datetime as _dt
+
+        from dlm.replay.models import SectionSnapshot
+        from dlm.replay.store import ReplayStore
+
+        corpus = tmp_path / "corpus.zst"
+        idx = tmp_path / "index.json"
+        store = ReplayStore.at(corpus, idx)
+        snaps = [
+            SectionSnapshot(
+                section_id="0000000000000001",
+                section_type="prose",
+                content="abcd",
+                first_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+                last_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+            ),
+            SectionSnapshot(
+                section_id="0000000000000002",
+                section_type="prose",
+                content="efgh",
+                first_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+                last_seen_at=_dt(2026, 4, 19, tzinfo=UTC).replace(tzinfo=None),
+            ),
+        ]
+        store.append_many(snaps)
+
+        text, _sha = calibration_text_from_replay(corpus_path=corpus, index_path=idx, max_chars=4)
+        assert text == "abcd"
tests/unit/export/test_merge.py (added)
21 lines changed
@@ -0,0 +1,21 @@
+"""Pure merge helper coverage."""
+
+from __future__ import annotations
+
+import pytest
+
+from dlm.export.errors import UnsafeMergeError
+from dlm.export.merge import check_merge_safety
+from dlm.export.plan import ExportPlan
+
+
+def test_check_merge_safety_delegates_to_plan() -> None:
+    check_merge_safety(ExportPlan(merged=False), was_qlora=True)
+
+
+def test_check_merge_safety_refuses_unsafe_qlora_merge() -> None:
+    with pytest.raises(UnsafeMergeError, match="QLoRA"):
+        check_merge_safety(
+            ExportPlan(merged=True, dequantize_confirmed=False),
+            was_qlora=True,
+        )
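The safety gate this file exercises can be sketched standalone. This is an illustrative reimplementation of the behavior the two tests pin down, not the `dlm.export.merge` source; the field names mirror the test's imports but are otherwise assumptions.

```python
from dataclasses import dataclass


class UnsafeMergeError(RuntimeError):
    """Raised when a merge would bake quantization error into the export."""


@dataclass(frozen=True)
class ExportPlan:
    merged: bool = False
    dequantize_confirmed: bool = False


def check_merge_safety(plan: ExportPlan, *, was_qlora: bool) -> None:
    # Merging a QLoRA-trained adapter into quantized base weights is lossy,
    # so require an explicit dequantize confirmation before proceeding.
    # Unmerged exports and non-QLoRA adapters pass through unconditionally.
    if plan.merged and was_qlora and not plan.dequantize_confirmed:
        raise UnsafeMergeError("QLoRA merge requires dequantize confirmation")
```

The happy path is a no-op, which is why the first test simply calls the function and relies on the absence of an exception.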
tests/unit/export/test_precision_safety.py (modified)
10 lines changed
@@ -54,6 +54,10 @@ class TestWasTrainedWithQlora:
         with pytest.raises(PreflightError, match="training_run_json"):
             was_trained_with_qlora(tmp_path, strict_training_run=True)

+    def test_malformed_pinned_versions_falls_back_to_false(self, tmp_path: Path) -> None:
+        (tmp_path / "pinned_versions.json").write_text("not json", encoding="utf-8")
+        assert was_trained_with_qlora(tmp_path) is False
+

 class TestResolvePrecisionSafety:
     def test_unmerged_export_is_safe(self, tmp_path: Path) -> None:
tests/unit/export/test_preflight.py (modified)
73 lines changed
@@ -4,6 +4,7 @@ from __future__ import annotations

 import json
 from pathlib import Path
+from types import SimpleNamespace

 import pytest

@@ -12,7 +13,9 @@ from dlm.export.errors import PreflightError
 from dlm.export.preflight import (
     check_adapter_config,
     check_chat_template,
+    check_pretokenizer_fingerprint,
     check_tokenizer_vocab,
+    check_vl_target_modules_lm_only,
     check_was_adapter_qlora,
 )

@@ -90,6 +93,14 @@ class TestTokenizerVocab:
         with pytest.raises(PreflightError, match="cannot determine vocab"):
             check_tokenizer_vocab(tmp_path)

+    def test_malformed_tokenizer_json_raises(self, tmp_path: Path) -> None:
+        (tmp_path / "tokenizer_config.json").write_text(
+            json.dumps({"chat_template": "{{messages}}"})
+        )
+        (tmp_path / "tokenizer.json").write_text("not json {{{")
+        with pytest.raises(PreflightError, match="cannot parse"):
+            check_tokenizer_vocab(tmp_path)
+

 class TestChatTemplate:
     def test_present_ok(self, tmp_path: Path) -> None:
@@ -114,6 +125,11 @@ class TestChatTemplate:
         with pytest.raises(PreflightError, match="missing"):
             check_chat_template(tmp_path, required=True)

+    def test_malformed_config_raises(self, tmp_path: Path) -> None:
+        (tmp_path / "tokenizer_config.json").write_text("not json {{{")
+        with pytest.raises(PreflightError, match="cannot parse"):
+            check_chat_template(tmp_path, required=True)
+

 class TestQloraFlag:
     def test_missing_file_returns_false(self, tmp_path: Path) -> None:
@@ -137,3 +153,32 @@ class TestQloraFlag:
         (tmp_path / "training_run.json").write_text("not json")
         with pytest.raises(PreflightError, match="training_run_json"):
             check_was_adapter_qlora(tmp_path)
+
+
+class TestPretokenizerFingerprint:
+    def test_failed_probe_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setattr(
+            "dlm.base_models.probes.probe_pretokenizer_hash",
+            lambda _spec: SimpleNamespace(skipped=False, passed=False, detail="mismatch"),
+        )
+
+        with pytest.raises(PreflightError, match="pre-tokenizer fingerprint mismatch"):
+            check_pretokenizer_fingerprint(_SPEC)
+
+
+class TestVlTargetModulesLmOnly:
+    def test_missing_config_is_noop(self, tmp_path: Path) -> None:
+        check_vl_target_modules_lm_only(tmp_path)
+
+    def test_malformed_config_is_noop(self, tmp_path: Path) -> None:
+        (tmp_path / "adapter_config.json").write_text("not json {{{")
+        check_vl_target_modules_lm_only(tmp_path)
+
+    def test_string_pattern_target_modules_is_noop(self, tmp_path: Path) -> None:
+        _write_adapter_config(tmp_path, target_modules=".*q_proj.*")
+        check_vl_target_modules_lm_only(tmp_path)
+
+    def test_vision_targets_raise(self, tmp_path: Path) -> None:
+        _write_adapter_config(tmp_path, target_modules=["q_proj", "vision_tower.block.0"])
+        with pytest.raises(PreflightError, match="vision-tower modules"):
+            check_vl_target_modules_lm_only(tmp_path)
tests/unit/export/test_runner.py (modified)
301 lines changed
@@ -101,6 +101,21 @@ def _setup_store(tmp_path: Path, *, use_qlora: bool = False) -> tuple[Path, Any,
     return cached_base, store, vendor


+def _setup_named_store(tmp_path: Path) -> tuple[Path, Any, Path]:
+    cached_base, store, vendor = _setup_store(tmp_path)
+    knowledge = store.adapter_version_for("knowledge", 2)
+    knowledge.mkdir(parents=True)
+    (knowledge / "adapter_config.json").write_text(
+        json.dumps({"base_model_name_or_path": _SPEC.hf_id, "peft_type": "LORA"})
+    )
+    (knowledge / "tokenizer_config.json").write_text(
+        json.dumps({"vocab_size": 32000, "chat_template": "{{messages}}"})
+    )
+    (knowledge / "training_run.json").write_text(json.dumps({"use_qlora": False}))
+    store.set_current_adapter_for("knowledge", knowledge)
+    return cached_base, store, vendor
+
+
 def _relative_file_bytes(root: Path) -> dict[str, bytes]:
     return {
         str(path.relative_to(root)): path.read_bytes()
@@ -110,6 +125,11 @@ def _relative_file_bytes(root: Path) -> dict[str, bytes]:


 class TestHappyPath:
+    def test_default_ollama_name_lowercases_dlm_id(self) -> None:
+        from dlm.export.runner import default_ollama_name
+
+        assert default_ollama_name("01ABCDEF", 7) == "dlm-01abcdef:v0007"
+
     def test_unmerged_export_emits_base_and_adapter(self, tmp_path: Path) -> None:
         cached_base, store, vendor = _setup_store(tmp_path)
         plan = ExportPlan(quant="Q4_K_M", merged=False)
@@ -332,6 +352,36 @@ class TestMergeGate:
         # No subprocess should have launched on the safety-gate path.
         assert recorder.commands == []

+    def test_merged_export_delegates_to_merge_path(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        cached_base, store, vendor = _setup_store(tmp_path, use_qlora=False)
+        plan = ExportPlan(merged=True, dequantize_confirmed=True)
+        recorder = _SubprocessRecorder(store.export_quant_dir(plan.quant))
+        seen: list[dict[str, object]] = []
+
+        def _fake_merge_path(**kwargs: object) -> None:
+            seen.append(kwargs)
+
+        monkeypatch.setattr("dlm.export.runner._perform_merge_path", _fake_merge_path)
+
+        result = run_export(
+            store,
+            _SPEC,
+            plan,
+            cached_base_dir=cached_base,
+            subprocess_runner=recorder,
+            vendor_override=vendor,
+            skip_ollama=True,
+            vocab_checker=lambda _a, _g: None,
+            embedding_checker=lambda _a, _g: None,
+        )
+
+        assert result.merged is True
+        assert len(seen) == 1
+        assert seen[0]["adapter_path"] == store.resolve_current_adapter()
+        assert seen[0]["was_qlora"] is False
+

 class TestDefaultVocabCheck:
     """Default path loads the adapter tokenizer-vocab and compares against the base GGUF."""
@@ -691,6 +741,54 @@ class TestMissingAdapter:
                 subprocess_runner=lambda _cmd: None,
             )

+    def test_missing_adapter_override_raises(self, tmp_path: Path) -> None:
+        cached_base, store, vendor = _setup_store(tmp_path)
+
+        with pytest.raises(ExportError, match="adapter_path_override .* does not exist"):
+            run_export(
+                store,
+                _SPEC,
+                ExportPlan(),
+                cached_base_dir=cached_base,
+                subprocess_runner=lambda _cmd: None,
+                vendor_override=vendor,
+                adapter_path_override=tmp_path / "missing",
+            )
+
+    def test_missing_named_adapter_raises(self, tmp_path: Path) -> None:
+        cached_base, store, vendor = _setup_store(tmp_path)
+
+        with pytest.raises(ExportError, match="run `dlm train` before exporting for adapter"):
+            run_export(
+                store,
+                _SPEC,
+                ExportPlan(),
+                cached_base_dir=cached_base,
+                subprocess_runner=lambda _cmd: None,
+                vendor_override=vendor,
+                adapter_name="knowledge",
+            )
+
+    def test_named_adapter_export_uses_named_pointer(self, tmp_path: Path) -> None:
+        cached_base, store, vendor = _setup_named_store(tmp_path)
+        recorder = _SubprocessRecorder(store.export_quant_dir("Q4_K_M"))
+
+        result = run_export(
+            store,
+            _SPEC,
+            ExportPlan(),
+            cached_base_dir=cached_base,
+            subprocess_runner=recorder,
+            vendor_override=vendor,
+            skip_ollama=True,
+            vocab_checker=lambda _a, _g: None,
+            embedding_checker=lambda _a, _g: None,
+            adapter_name="knowledge",
+        )
+
+        assert result.export_dir == store.export_quant_dir("Q4_K_M")
+        assert len(recorder.commands) == 3
+

 class TestManifestAppend:
     def test_exports_list_grows(self, tmp_path: Path) -> None:
@@ -751,3 +849,179 @@ class TestManifestAppend:
         # Peer released → no export summary landed (we errored before save).
         manifest = load_manifest(store.manifest)
         assert len(manifest.exports) == 0
+
+
+class TestRunnerInternals:
+    def test_cached_base_missing_manifest_is_false(self, tmp_path: Path) -> None:
+        from dlm.export.runner import _cached_base_matches
+
+        export_dir = tmp_path / "exports" / "Q4_K_M"
+        export_dir.mkdir(parents=True)
+        base_gguf = export_dir / "base.Q4_K_M.gguf"
+        base_gguf.write_bytes(b"cached bytes")
+
+        assert _cached_base_matches(export_dir, base_gguf, "Q4_K_M") is False
+
+    def test_cached_base_quant_mismatch_is_false(self, tmp_path: Path) -> None:
+        from dlm.export.manifest import ExportManifest
+        from dlm.export.runner import _cached_base_matches
+
+        export_dir = tmp_path / "exports" / "Q4_K_M"
+        export_dir.mkdir(parents=True)
+        base_gguf = export_dir / "base.Q4_K_M.gguf"
+        base_gguf.write_bytes(b"cached bytes")
+        manifest = ExportManifest(
+            target="ollama",
+            quant="Q5_K_M",
+            created_at=datetime(2026, 4, 23, 12, 0, 0),
+            created_by="dlm-test",
+            base_model_hf_id="org/base",
+            base_model_revision="a" * 40,
+            adapter_version=1,
+            artifacts=[],
+        )
+        (export_dir / "export_manifest.json").write_text(
+            manifest.model_dump_json(indent=2) + "\n",
+            encoding="utf-8",
+        )
+
+        assert _cached_base_matches(export_dir, base_gguf, "Q4_K_M") is False
+
+    def test_cached_base_without_recorded_artifact_is_false(self, tmp_path: Path) -> None:
+        from dlm.export.manifest import ExportManifest, build_artifact
+        from dlm.export.runner import _cached_base_matches
+
+        export_dir = tmp_path / "exports" / "Q4_K_M"
+        export_dir.mkdir(parents=True)
+        base_gguf = export_dir / "base.Q4_K_M.gguf"
+        other = export_dir / "other.gguf"
+        base_gguf.write_bytes(b"cached bytes")
+        other.write_bytes(b"other bytes")
+        manifest = ExportManifest(
+            target="ollama",
+            quant="Q4_K_M",
+            created_at=datetime(2026, 4, 23, 12, 0, 0),
+            created_by="dlm-test",
+            base_model_hf_id="org/base",
+            base_model_revision="a" * 40,
+            adapter_version=1,
+            artifacts=[build_artifact(export_dir, other)],
+        )
+        (export_dir / "export_manifest.json").write_text(
+            manifest.model_dump_json(indent=2) + "\n",
+            encoding="utf-8",
+        )
+
+        assert _cached_base_matches(export_dir, base_gguf, "Q4_K_M") is False
+
+    def test_cached_imatrix_without_existing_file_returns_none(self, tmp_path: Path) -> None:
+        from dlm.export.runner import _resolve_or_build_imatrix
+
+        cached_base, store, _vendor = _setup_store(tmp_path)
+        fp16 = tmp_path / "base.fp16.gguf"
+        fp16.write_bytes(b"fp16")
+
+        assert (
+            _resolve_or_build_imatrix(
+                export_dir=tmp_path,
+                fp16_path=fp16,
+                plan=ExportPlan(quant="Q4_K_M", imatrix="cached"),
+                run=lambda _cmd: None,
+                vendor_override=None,
+                spec=_SPEC,
+                store=store,
+            )
+            is None
+        )
+
+    def test_auto_imatrix_cache_hit_logs_and_returns_path(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        from types import SimpleNamespace
+
+        from dlm.export.runner import _resolve_or_build_imatrix
+
+        cached_base, store, _vendor = _setup_store(tmp_path)
+        fp16 = tmp_path / "base.fp16.gguf"
+        fp16.write_bytes(b"fp16")
+        imatrix = tmp_path / "imatrix.gguf"
+        imatrix.write_bytes(b"imatrix")
+
+        monkeypatch.setattr(
+            "dlm.export.imatrix.calibration_text_from_replay",
+            lambda **_kwargs: ("calibration text", "abc123"),
+        )
+        monkeypatch.setattr(
+            "dlm.export.imatrix.resolve_imatrix",
+            lambda *_args, **_kwargs: SimpleNamespace(path=imatrix, sha256="abcdef123456"),
+        )
+        caplog.set_level(logging.INFO, logger="dlm.export.runner")
+
+        resolved = _resolve_or_build_imatrix(
+            export_dir=tmp_path,
+            fp16_path=fp16,
+            plan=ExportPlan(quant="Q4_K_M", imatrix="auto"),
+            run=lambda _cmd: None,
+            vendor_override=None,
+            spec=_SPEC,
+            store=store,
+        )
+
+        assert resolved == imatrix
+        assert "imatrix: cache hit (" in caplog.text
+
+    def test_run_ollama_stage_records_detected_version(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        from dlm.export.runner import _run_ollama_stage
+
+        cached_base, store, _vendor = _setup_store(tmp_path)
+        export_dir = store.export_quant_dir("Q4_K_M")
+        export_dir.mkdir(parents=True, exist_ok=True)
+        base_gguf = export_dir / "base.Q4_K_M.gguf"
+        base_gguf.write_bytes(b"base")
+        adapter = store.resolve_current_adapter()
+        assert adapter is not None
+
+        monkeypatch.setattr("dlm.export.ollama.check_ollama_version", lambda: (1, 2, 3))
+        monkeypatch.setattr("dlm.export.draft_registry.resolve_draft", lambda *args, **kwargs: None)
+        monkeypatch.setattr(
+            "dlm.export.ollama.render_modelfile",
+            lambda _ctx: "FROM ./base.Q4_K_M.gguf\n",
+        )
+
+        seen: list[str] = []
+
+        def _create(*, name: str, modelfile_path: Path, cwd: Path) -> str:
+            seen.append(name)
+            assert modelfile_path.exists()
+            assert cwd == export_dir
+            return "created"
+
+        monkeypatch.setattr("dlm.export.ollama.ollama_create", _create)
+        monkeypatch.setattr("dlm.export.ollama.ollama_run", lambda **_kwargs: "unused")
+
+        modelfile_path, name, ver_str, smoke_first_line = _run_ollama_stage(
+            store=store,
+            spec=_SPEC,
+            plan=ExportPlan(quant="Q4_K_M"),
+            adapter_path=adapter,
+            export_dir=export_dir,
+            base_gguf_path=base_gguf,
+            adapter_version=1,
+            system_prompt=None,
+            source_dlm_path=None,
+            skip_smoke=True,
+            ollama_create_runner=None,
+            ollama_run_runner=None,
+            training_sequence_len=None,
+            override_temperature=None,
+            override_top_p=None,
+            draft_override=None,
+            draft_disabled=False,
+        )
+
+        assert modelfile_path.exists()
+        assert name == seen[0]
+        assert ver_str == "1.2.3"
+        assert smoke_first_line is None
tests/unit/export/test_smoke.py (modified)
33 lines changed
@@ -330,6 +330,26 @@ class TestChatCompletion:


 class TestSmokeHelpers:
+    def test_reserve_local_port_returns_loopback_port(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        class _FakeSocket:
+            def __enter__(self) -> _FakeSocket:
+                return self
+
+            def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
+                return None
+
+            def bind(self, address: tuple[str, int]) -> None:
+                assert address == ("127.0.0.1", 0)
+
+            def getsockname(self) -> tuple[str, int]:
+                return ("127.0.0.1", 43123)
+
+        monkeypatch.setattr(smoke_mod.socket, "socket", lambda *_args, **_kwargs: _FakeSocket())
+
+        assert smoke_mod.reserve_local_port() == 43123
+
     def test_normalize_message_content(self) -> None:
         assert smoke_mod._normalize_message_content("  hello  ") == "hello"
         assert (
@@ -338,6 +358,7 @@ class TestSmokeHelpers:
             )
             == "first\nsecond"
         )
+        assert smoke_mod._normalize_message_content([{"text": "first"}, "skip-me"]) == "first"
         assert smoke_mod._normalize_message_content([{"text": "   "}]) is None
         assert smoke_mod._normalize_message_content(3) is None
tests/unit/export/test_vendoring.py (modified)
49 lines changed
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import errno
 from pathlib import Path

 import pytest
@@ -47,6 +48,19 @@ class TestLlamaCppRoot:
         root = _populate_vendor(tmp_path / "llama.cpp")
         assert llama_cpp_root(override=root) == root

+    def test_enumeration_failure_raises(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        root = _populate_vendor(tmp_path / "llama.cpp")
+
+        def _raise_iterdir() -> object:
+            raise OSError(errno.EIO, "boom")
+
+        monkeypatch.setattr(Path, "iterdir", lambda self: _raise_iterdir())
+
+        with pytest.raises(VendoringError, match="cannot enumerate"):
+            llama_cpp_root(override=root)
+

 class TestScriptResolvers:
     def test_convert_hf_resolves(self, tmp_path: Path) -> None:
@@ -100,6 +114,23 @@ class TestLlamaBinaries:
         with pytest.raises(VendoringError, match="llama-server"):
             llama_server_bin(override=root)

+    def test_path_lookup_returns_binary_when_vendor_missing(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setenv("PATH", str(tmp_path))
+        fake = tmp_path / "llama-quantize"
+        fake.write_text("#!/bin/sh\n", encoding="utf-8")
+        fake.chmod(0o755)
+        monkeypatch.setattr(
+            "shutil.which", lambda name: str(fake) if name == "llama-quantize" else None
+        )
+
+        path = llama_quantize_bin(
+            override=_populate_vendor(tmp_path / "llama.cpp", with_binary=False)
+        )
+
+        assert path == fake
+
     def test_dlm_llama_cpp_build_env_preferred(
         self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
     ) -> None:
tests/unit/export/test_vl_gguf.py (modified)
131 lines changed
@@ -19,9 +19,10 @@ from typing import Any
 
 import pytest
 
+import dlm.export.vl_gguf as vl_gguf
 from dlm.base_models.schema import BaseModelSpec, VlPreprocessorPlan
 from dlm.export.arch_probe import ArchProbeResult, SupportLevel
-from dlm.export.errors import VlGgufUnsupportedError
+from dlm.export.errors import ExportError, VlGgufUnsupportedError
 from dlm.export.plan import ExportPlan
 from dlm.export.vl_gguf import VlGgufResult, run_vl_gguf_export
 from dlm.store.paths import for_dlm
@@ -132,6 +133,38 @@ def _populate_adapter(store: Any, version: int = 1) -> Path:
     return adapter
 
 
+def _populate_named_adapter(store: Any, name: str, version: int = 1) -> Path:
+    """Write a minimally-valid named adapter checkpoint under `adapter/<name>/`."""
+    store.ensure_layout()
+    adapter: Path = store.adapter_version_for(name, version)
+    adapter.mkdir(parents=True, exist_ok=True)
+    (adapter / "adapter_config.json").write_text(
+        json.dumps(
+            {
+                "base_model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",
+                "target_modules": ["q_proj", "v_proj"],
+                "r": 16,
+            }
+        ),
+        encoding="utf-8",
+    )
+    (adapter / "tokenizer_config.json").write_text(
+        json.dumps(
+            {
+                "vocab_size": 151643,
+                "chat_template": "{{ 'hi' }}",
+            }
+        ),
+        encoding="utf-8",
+    )
+    (adapter / "training_run.json").write_text(
+        json.dumps({"use_qlora": False}),
+        encoding="utf-8",
+    )
+    store.set_current_adapter_for(name, adapter)
+    return adapter
+
+
 class TestRefusals:
     """Covers `_assert_supported` — the three preconditions + adapter gate."""
 
@@ -302,3 +335,82 @@ class TestHappyPath:
         paths = {a["path"] for a in manifest["artifacts"]}
         assert "base.Q4_K_M.gguf" in paths
         assert "Modelfile" in paths
+
+    def test_named_adapter_export_uses_named_current_pointer(self, tmp_path: Path) -> None:
+        store = for_dlm(_VALID_ULID, home=tmp_path)
+        flat = _populate_adapter(store, version=1)
+        named = _populate_named_adapter(store, "knowledge", version=2)
+        cached_base = tmp_path / "base-cache"
+        cached_base.mkdir()
+        llama_cpp_root = _fixture_llama_cpp_root(tmp_path)
+
+        merge_calls: list[tuple[Path, Path, Path]] = []
+
+        def _recorder(args: Any) -> None:
+            if args and args[0].endswith("llama-quantize"):
+                Path(args[2]).write_bytes(b"stub-gguf-bytes")
+
+        def _merge(adapter: Path, out_dir: Path, *, cached_base_dir: Path) -> None:
+            merge_calls.append((adapter, out_dir, cached_base_dir))
+            out_dir.mkdir(parents=True, exist_ok=True)
+
+        result = run_vl_gguf_export(
+            store,
+            _qwen2vl_spec(),
+            _merged_plan(),
+            verdict=_supported_verdict(),
+            cached_base_dir=cached_base,
+            adapter_name="knowledge",
+            subprocess_runner=_recorder,
+            merge_runner=_merge,
+            llama_cpp_root_override=llama_cpp_root,
+        )
+
+        assert len(merge_calls) == 1
+        assert merge_calls[0][0] == named
+        assert merge_calls[0][0] != flat
+        manifest = json.loads(result.manifest_path.read_text(encoding="utf-8"))
+        assert manifest["adapter_version"] == 2
+
+    def test_missing_quantize_output_raises(self, tmp_path: Path) -> None:
+        store = for_dlm(_VALID_ULID, home=tmp_path)
+        _populate_adapter(store)
+        cached_base = tmp_path / "base-cache"
+        cached_base.mkdir()
+        llama_cpp_root = _fixture_llama_cpp_root(tmp_path)
+
+        def _recorder(_args: Any) -> None:
+            return None
+
+        def _merge(_adapter: Path, out_dir: Path, *, cached_base_dir: Path) -> None:
+            out_dir.mkdir(parents=True, exist_ok=True)
+
+        with pytest.raises(ExportError, match="expected .*base.Q4_K_M.gguf"):
+            run_vl_gguf_export(
+                store,
+                _qwen2vl_spec(),
+                _merged_plan(),
+                verdict=_supported_verdict(),
+                cached_base_dir=cached_base,
+                subprocess_runner=_recorder,
+                merge_runner=_merge,
+                llama_cpp_root_override=llama_cpp_root,
+            )
+
+
+class TestHelpers:
+    def test_version_parser_falls_back_to_one_for_non_version_dir(self, tmp_path: Path) -> None:
+        assert vl_gguf._version_from_dir_name(tmp_path / "merged-adapter") == 1
+
+    def test_default_runner_delegates_to_run_checked(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        recorded: dict[str, object] = {}
+
+        def _fake_run_checked(args: list[str], *, timeout: int) -> object:
+            recorded["args"] = args
+            recorded["timeout"] = timeout
+            return "ok"
+
+        monkeypatch.setattr(vl_gguf, "run_checked", _fake_run_checked)
+        out = vl_gguf._default_runner(("python", "tool.py"))
+        assert out == "ok"
+        assert recorded == {"args": ["python", "tool.py"], "timeout": 60 * 60}
tests/unit/export/test_vl_snapshot.py (modified)
101 lines changed
@@ -12,6 +12,7 @@ Covers:
 
 from __future__ import annotations
 
+import json
 from pathlib import Path
 
 import pytest
@@ -96,6 +97,12 @@ class TestRefusals:
         with pytest.raises(ExportError, match="only vision-language bases"):
             run_vl_snapshot_export(populated_store, _text_spec())
 
+    def test_missing_vl_preprocessor_plan_refused(self, populated_store) -> None:
+        spec = _vl_spec()
+        object.__setattr__(spec, "vl_preprocessor_plan", None)
+        with pytest.raises(ExportError, match="no vl_preprocessor_plan"):
+            run_vl_snapshot_export(populated_store, spec)
+
     def test_missing_adapter_refused(self, tmp_path: Path) -> None:
         store = for_dlm(_VALID_ULID, home=tmp_path)
         store.ensure_layout()
@@ -132,6 +139,68 @@ class TestSnapshotLayout:
         result = run_vl_snapshot_export(populated_store, _vl_spec())
         assert (result.adapter_dir / "adapter_model.safetensors").read_bytes() == b"new bytes"
 
+    def test_named_adapter_export_uses_named_current_pointer(self, populated_store) -> None:
+        named = populated_store.adapter_version_for("knowledge", 7)
+        named.mkdir(parents=True, exist_ok=True)
+        (named / "adapter_config.json").write_text('{"r": 32}', encoding="utf-8")
+        (named / "adapter_model.safetensors").write_bytes(b"named bytes")
+        populated_store.set_current_adapter_for("knowledge", named)
+
+        result = run_vl_snapshot_export(
+            populated_store,
+            _vl_spec(),
+            adapter_name="knowledge",
+        )
+
+        assert (result.adapter_dir / "adapter_model.safetensors").read_bytes() == b"named bytes"
+        manifest = load_vl_snapshot_manifest(result.export_dir)
+        assert manifest.adapter_version == 7
+        assert manifest.adapter_name == "knowledge"
+
+    def test_adapter_override_uses_provided_dir(self, populated_store, tmp_path: Path) -> None:
+        override = tmp_path / "merged-adapter"
+        override.mkdir()
+        (override / "adapter_model.safetensors").write_bytes(b"override bytes")
+
+        result = run_vl_snapshot_export(
+            populated_store,
+            _vl_spec(),
+            adapter_path_override=override,
+        )
+
+        assert (result.adapter_dir / "adapter_model.safetensors").read_bytes() == b"override bytes"
+        manifest = load_vl_snapshot_manifest(result.export_dir)
+        assert manifest.adapter_version == 1
+
+    def test_missing_adapter_override_refused(self, populated_store, tmp_path: Path) -> None:
+        with pytest.raises(ExportError, match="adapter_path_override .* does not exist"):
+            run_vl_snapshot_export(
+                populated_store,
+                _vl_spec(),
+                adapter_path_override=tmp_path / "missing",
+            )
+
+    def test_processor_save_pretrained_writes_processor_artifact(self, populated_store) -> None:
+        class _Processor:
+            def save_pretrained(self, out_dir: str) -> None:
+                Path(out_dir, "processor_config.json").write_text("{}", encoding="utf-8")
+
+        result = run_vl_snapshot_export(populated_store, _vl_spec(), processor=_Processor())
+
+        assert (result.processor_dir / "processor_config.json").exists()
+        manifest = load_vl_snapshot_manifest(result.export_dir)
+        paths = {entry.path for entry in manifest.artifacts}
+        assert "processor/processor_config.json" in paths
+
+    def test_noncallable_processor_save_is_ignored(self, populated_store) -> None:
+        class _Processor:
+            save_pretrained = "not-callable"
+
+        result = run_vl_snapshot_export(populated_store, _vl_spec(), processor=_Processor())
+
+        assert result.processor_dir.exists()
+        assert not any(result.processor_dir.iterdir())
+
 
 class TestManifestContent:
     def test_export_target_is_hf_snapshot(self, populated_store) -> None:
@@ -203,6 +272,14 @@ class TestManifestLoadFailures:
         with pytest.raises(ExportManifestError, match="cannot parse"):
             load_vl_snapshot_manifest(tmp_path)
 
+    def test_invalid_shape_raises(self, tmp_path: Path) -> None:
+        (tmp_path / SNAPSHOT_MANIFEST_FILENAME).write_text(
+            json.dumps({"created_by": "dlm-test"}),
+            encoding="utf-8",
+        )
+        with pytest.raises(ExportManifestError, match="invalid shape"):
+            load_vl_snapshot_manifest(tmp_path)
+
 
 class TestManifestModelDirect:
     def test_frozen(self) -> None:
tests/unit/hardware/test_capabilities.py (modified)
88 lines changed
@@ -2,10 +2,13 @@
 
 from __future__ import annotations
 
+from types import SimpleNamespace
+from unittest.mock import patch
+
 import pytest
 
 from dlm.hardware.backend import Backend
-from dlm.hardware.capabilities import probe
+from dlm.hardware.capabilities import _accelerate_version, _rocm_arch_supports_bf16, probe
 from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps, force_rocm
 
 
@@ -30,6 +33,26 @@ class TestProbeCuda:
         # flash_attn gated on SM>=8.0 regardless of package availability
         assert caps.has_flash_attention is False
 
+    def test_cuda_sm_probe_failure_yields_unknown_sm(self) -> None:
+        with force_cuda():
+            with patch("torch.cuda.get_device_capability", side_effect=RuntimeError("boom")):
+                caps = probe()
+        assert caps.sm is None
+
+    def test_cuda_vram_probe_failure_yields_unknown_vram(self) -> None:
+        with force_cuda():
+            with patch("torch.cuda.mem_get_info", side_effect=RuntimeError("boom")):
+                caps = probe()
+        assert caps.vram_gb is None
+
+    def test_cuda_flash_attention_true_when_package_and_sm_supported(self) -> None:
+        with (
+            patch("dlm.hardware.capabilities._module_available", lambda name: name == "flash_attn"),
+            force_cuda(sm=(8, 0)),
+        ):
+            caps = probe()
+        assert caps.has_flash_attention is True
+
 
 class TestProbeRocm:
     def test_rocm_reports_hip_version(self) -> None:
@@ -41,6 +64,20 @@ class TestProbeRocm:
         assert caps.determinism_class == "best-effort"
         assert caps.has_flash_attention is False
 
+    def test_rocm_arch_probe_failure_yields_unknown_arch(self) -> None:
+        with force_rocm():
+            with patch("torch.cuda.get_device_properties", side_effect=RuntimeError("boom")):
+                caps = probe()
+        assert caps.rocm_arch is None
+
+    def test_rocm_arch_probe_missing_name_yields_unknown_arch(self) -> None:
+        with force_rocm():
+            with patch(
+                "torch.cuda.get_device_properties", return_value=SimpleNamespace(name="AMD")
+            ):
+                caps = probe()
+        assert caps.rocm_arch is None
+
 
 class TestProbeMps:
     def test_mps_caps(self) -> None:
@@ -54,6 +91,14 @@
         assert caps.determinism_class == "best-effort"
         assert caps.has_flash_attention is False
 
+    def test_mps_never_reports_flash_attention(self) -> None:
+        with (
+            patch("dlm.hardware.capabilities._module_available", lambda name: name == "flash_attn"),
+            force_mps(),
+        ):
+            caps = probe()
+        assert caps.has_flash_attention is False
+
 
 class TestMlxAvailability:
     def test_non_mps_never_reports_mlx(self) -> None:
@@ -133,3 +178,14 @@ class TestTelemetryPosture:
             caps = probe()
         assert caps.telemetry_posture["HF_HUB_DISABLE_TELEMETRY"] == "<unset>"
         assert caps.telemetry_posture["DO_NOT_TRACK"] == "<unset>"
+
+
+class TestCoverageEdges:
+    def test_rocm_arch_none_is_not_bf16_capable(self) -> None:
+        assert _rocm_arch_supports_bf16(None) is False
+
+    def test_accelerate_version_missing_returns_none(self) -> None:
+        from importlib.metadata import PackageNotFoundError
+
+        with patch("importlib.metadata.version", side_effect=PackageNotFoundError):
+            assert _accelerate_version() is None
tests/unit/hardware/test_plan.py (modified)
76 lines changed
@@ -2,9 +2,13 @@
 
 from __future__ import annotations
 
+from unittest.mock import patch
+
+import pytest
+
 from dlm.doc.schema import TrainingConfig
 from dlm.hardware.capabilities import probe
-from dlm.hardware.plan import resolve
+from dlm.hardware.plan import _build_reason, resolve
 from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps
 
 
@@ -102,6 +106,15 @@ class TestAttentionPicker:
         plan = resolve(_cfg(), caps, base_params=135_000_000, seq_len=1024)
         assert plan.attn_implementation == "sdpa"
 
+    def test_flash_attention_selected_when_available(self) -> None:
+        with (
+            patch("dlm.hardware.capabilities._module_available", lambda name: name == "flash_attn"),
+            force_cuda(sm=(8, 0)),
+        ):
+            caps = probe()
+        plan = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048)
+        assert plan.attn_implementation == "flash_attention_2"
+
 
 class TestQloraGating:
     def test_qlora_requested_on_cuda_without_bnb_raises(self) -> None:
@@ -170,6 +183,12 @@ class TestBatchAndGradAccumResolution:
         loose_plan = resolve(_cfg(), loose_caps, base_params=1_500_000_000, seq_len=2048)
         assert tight_plan.micro_batch_size <= loose_plan.micro_batch_size
 
+    def test_tiny_budget_breaks_auto_micro_batch_at_one(self) -> None:
+        with force_cuda(sm=(8, 0), vram_gb=2.0):
+            caps = probe()
+        plan = resolve(_cfg(), caps, base_params=3_000_000_000, seq_len=4096)
+        assert plan.micro_batch_size == 1
+
 
 class TestGradientCheckpointing:
     def test_enabled_when_memory_tight(self) -> None:
@@ -248,3 +267,35 @@ class TestPlanSerialization:
         plan = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048)
         assert "precision=bf16" in plan.reason
         assert "attn=" in plan.reason
+
+
+class TestResolverCoverageEdges:
+    def test_world_size_must_be_positive(self) -> None:
+        with force_cuda(sm=(8, 0)):
+            caps = probe()
+        with pytest.raises(ValueError, match="world_size must be >= 1"):
+            resolve(_cfg(), caps, base_params=135_000_000, seq_len=512, world_size=0)
+
+    def test_multi_gpu_refusals_checked_when_world_size_gt_one(self) -> None:
+        with force_cuda(sm=(8, 0)):
+            caps = probe()
+        with (
+            patch("dlm.hardware.plan.check_multi_gpu_refusals") as multi_gpu,
+            patch("dlm.hardware.plan.check_refusals"),
+        ):
+            resolve(_cfg(), caps, base_params=135_000_000, seq_len=512, world_size=2)
+        multi_gpu.assert_called_once_with(caps, 2)
+
+    def test_build_reason_records_dora_and_galore_warning(self) -> None:
+        reason = _build_reason(
+            "bf16",
+            "sdpa",
+            False,
+            True,
+            adapter="dora",
+            optimizer="galore_adamw",
+            base_params=500_000_000,
+        )
+        assert "adapter=dora" in reason
+        assert "optim=galore_adamw" in reason
+        assert "warn=galore-small-base(500M<1B)" in reason
tests/unit/harvest/test_sway_reader.py (modified)
18 lines changed
@@ -119,6 +119,18 @@ class TestHappyPath:
         assert len(candidates) == 1
         assert candidates[0].confidence == 1.0
 
+    def test_invalid_confidence_defaults_to_one(self, tmp_path: Path) -> None:
+        broken_conf = {**_PROBE_FAIL_WITH_REF}
+        broken_conf["evidence"] = {
+            "prompt": "q?",
+            "reference": "a.",
+            "confidence": {"not": "numeric"},
+        }
+        report = _write(tmp_path, _full_report([broken_conf]))
+        candidates = read_sway_report(report)
+        assert len(candidates) == 1
+        assert candidates[0].confidence == 1.0
+
 
 class TestMissingReference:
     def test_strict_raises(self, tmp_path: Path) -> None:
tests/unit/inference/test_mlx_backend.py (modified)
30 lines changed
@@ -10,7 +10,11 @@ from types import ModuleType, SimpleNamespace
 import pytest
 
 from dlm.base_models import BASE_MODELS
-from dlm.inference.backends.mlx_backend import MlxBackend, _resolve_base_num_hidden_layers
+from dlm.inference.backends.mlx_backend import (
+    MlxBackend,
+    _resolve_base_num_hidden_layers,
+    stage_mlx_adapter_dir,
+)
 from dlm.inference.errors import AdapterNotFoundError
 from dlm.inference.mlx_adapter import MlxConversionError
 
@@ -123,3 +127,18 @@
         assert backend._workdir is None
         assert backend._model is None
         assert backend._tokenizer is None
+
+
+class TestStageMlxAdapterDir:
+    def test_unreadable_adapter_config_raises_conversion_error(self, tmp_path: Path) -> None:
+        adapter_dir = tmp_path / "adapter"
+        adapter_dir.mkdir()
+        (adapter_dir / "adapter_model.safetensors").write_bytes(b"fake")
+        (adapter_dir / "adapter_config.json").mkdir()
+
+        with pytest.raises(MlxConversionError, match="cannot read .*adapter_config.json"):
+            stage_mlx_adapter_dir(
+                adapter_dir,
+                tmp_path / "staged",
+                base_hf_id=BASE_MODELS["smollm2-135m"].hf_id,
+            )
tests/unit/lock/test_mismatch_policy.py (modified)
13 lines changed
@@ -272,6 +272,13 @@ class TestLicenseAcceptanceRule:
         msgs = [msg for _s, msg in classify_mismatches(prior, current)]
         assert any("url changed" in m for m in msgs)
 
+    def test_equal_acceptance_is_silent(self) -> None:
+        acceptance = self._acceptance()
+        prior = _lock(license_acceptance=acceptance)
+        current = _lock(license_acceptance=acceptance)
+        msgs = [msg for _s, msg in classify_mismatches(prior, current)]
+        assert not any("license_acceptance" in m for m in msgs)
+
     def test_both_none_is_silent(self) -> None:
         prior = _lock(license_acceptance=None)
         current = _lock(license_acceptance=None)
tests/unit/metrics/test_queries.py (modified)
256 lines changed
@@ -2,15 +2,29 @@
 
 from __future__ import annotations
 
+import sqlite3
 from datetime import UTC, datetime, timedelta
 from pathlib import Path
 
-from dlm.metrics.events import EvalEvent, PreferenceMineEvent, RunEnd, RunStart, StepEvent
+import pytest
+
+from dlm.metrics.events import (
+    EvalEvent,
+    GateEvent,
+    PreferenceMineEvent,
+    RunEnd,
+    RunStart,
+    StepEvent,
+    TokenizationEvent,
+)
 from dlm.metrics.queries import (
     evals_for_run,
     evals_to_dict,
+    gate_events_for_run,
+    latest_gate_events,
     latest_preference_mining,
     latest_run_id,
+    latest_tokenization,
     preference_mining_for_run,
     preference_mining_to_dict,
     preference_mining_totals,
@@ -18,6 +32,7 @@ from dlm.metrics.queries import (
     runs_to_dict,
     steps_for_run,
     steps_to_dict,
+    tokenization_for_run,
 )
 from dlm.metrics.recorder import MetricsRecorder
 
@@ -31,6 +46,34 @@ def _seed(store_root: Path) -> None:
             rec.record_step(StepEvent(run_id=run_id, step=step, loss=2.0 - 0.1 * step))
         rec.record_eval(EvalEvent(run_id=run_id, step=30, val_loss=1.5))
         rec.record_run_end(RunEnd(run_id=run_id, status="ok"))
+    rec.record_tokenization(
+        TokenizationEvent(
+            run_id=3,
+            total_sections=10,
+            cache_hits=7,
+            cache_misses=3,
+            total_tokenize_seconds=0.75,
+            cache_bytes_after=4096,
+        )
+    )
+    rec.record_gate(
+        GateEvent(
+            run_id=2,
+            adapter_name="tone",
+            mean_weight=0.8,
+            sample_count=12,
+            mode="trained",
+        )
+    )
+    rec.record_gate(
+        GateEvent(
+            run_id=2,
+            adapter_name="facts",
+            mean_weight=0.2,
+            sample_count=12,
+            mode="trained",
+        )
+    )
     rec.record_preference_mine(
         PreferenceMineEvent(
             run_id=2,
@@ -124,6 +167,133 @@ class TestLatestRunId:
             pass
         assert latest_run_id(tmp_path) is None
 
+    def test_none_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert latest_run_id(tmp_path) is None
+
+
+class TestTokenizationQueries:
+    def test_tokenization_for_run_returns_row_with_hit_rate(self, tmp_path: Path) -> None:
+        _seed(tmp_path)
+        row = tokenization_for_run(tmp_path, run_id=3)
+        assert row is not None
+        assert row.cache_hits == 7
+        assert row.hit_rate == 0.7
+
+    def test_tokenization_for_run_none_when_table_has_no_row(self, tmp_path: Path) -> None:
+        from dlm.metrics.db import connect
+
+        with connect(tmp_path) as _conn:
+            pass
+        assert tokenization_for_run(tmp_path, run_id=3) is None
+
+    def test_hit_rate_zero_when_total_lookups_is_zero(self) -> None:
+        from dlm.metrics.queries import TokenizationRow
+
+        row = TokenizationRow(
+            run_id=1,
+            total_sections=0,
+            cache_hits=0,
+            cache_misses=0,
+            total_tokenize_seconds=0.0,
+            cache_bytes_after=0,
+            at="2026-01-01T00:00:00Z",
+        )
+        assert row.hit_rate == 0.0
+
+    def test_tokenization_for_run_none_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert tokenization_for_run(tmp_path, run_id=1) is None
+
+    def test_latest_tokenization_returns_most_recent_row(self, tmp_path: Path) -> None:
+        _seed(tmp_path)
+        row = latest_tokenization(tmp_path)
+        assert row is not None
+        assert row.run_id == 3
+
+    def test_latest_tokenization_none_when_empty(self, tmp_path: Path) -> None:
+        from dlm.metrics.db import connect
+
+        with connect(tmp_path) as _conn:
+            pass
+        assert latest_tokenization(tmp_path) is None
+
+    def test_latest_tokenization_none_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert latest_tokenization(tmp_path) is None
+
+
+class TestGateQueries:
+    def test_gate_events_for_run_returns_rows_sorted_by_adapter(self, tmp_path: Path) -> None:
+        _seed(tmp_path)
+        rows = gate_events_for_run(tmp_path, run_id=2)
+        assert [row.adapter_name for row in rows] == ["facts", "tone"]
+
+    def test_gate_events_for_run_returns_empty_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert gate_events_for_run(tmp_path, run_id=2) == []
+
+    def test_latest_gate_events_returns_latest_run_rows(self, tmp_path: Path) -> None:
+        _seed(tmp_path)
+        rows = latest_gate_events(tmp_path)
+        assert [row.adapter_name for row in rows] == ["facts", "tone"]
+
+    def test_latest_gate_events_empty_when_table_empty(self, tmp_path: Path) -> None:
+        from dlm.metrics.db import connect
+
+        with connect(tmp_path) as _conn:
+            pass
+        assert latest_gate_events(tmp_path) == []
+
+    def test_latest_gate_events_empty_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert latest_gate_events(tmp_path) == []
+
 
 class TestPreferenceMiningQueries:
     def test_preference_mining_for_run_returns_oldest_first(self, tmp_path: Path) -> None:
@@ -155,6 +325,52 @@ class TestPreferenceMiningQueries:
         assert totals.total_mined_pairs == 3
         assert totals.total_skipped_prompts == 3
 
+    def test_preference_mining_for_run_returns_empty_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert preference_mining_for_run(tmp_path, run_id=2) == []
+
+    def test_latest_preference_mining_returns_none_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert latest_preference_mining(tmp_path) is None
+
+    def test_preference_mining_totals_none_when_table_empty(self, tmp_path: Path) -> None:
+        from dlm.metrics.db import connect
+
+        with connect(tmp_path) as _conn:
+            pass
+        assert preference_mining_totals(tmp_path) is None
+
+    def test_preference_mining_totals_none_on_sqlite_error(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.metrics.queries as queries_mod
+
+        def _boom(_store_root: Path) -> sqlite3.Connection:
+            raise sqlite3.OperationalError("boom")
+
+        monkeypatch.setattr(queries_mod, "connect", _boom)
+        assert preference_mining_totals(tmp_path) is None
+
 
 class TestDictSerialization:
     def test_runs_to_dict_shape(self, tmp_path: Path) -> None:
tests/unit/metrics/test_recorder.pymodified
84 lines changed — click to load
@@ -7,6 +7,7 @@ import sqlite3
 from collections.abc import Iterator
 from contextlib import contextmanager
 from pathlib import Path
+from typing import Any
 
 import pytest
 
@@ -14,15 +15,17 @@ from dlm.metrics.db import metrics_db_path
 from dlm.metrics.events import (
     EvalEvent,
     ExportEvent,
+    GateEvent,
     PreferenceMineEvent,
     RunEnd,
     RunStart,
     StepEvent,
+    TokenizationEvent,
 )
-from dlm.metrics.recorder import DlmTrainerCallback, MetricsRecorder
+from dlm.metrics.recorder import DlmTrainerCallback, MetricsRecorder, _maybe_float
 
 
-def _select_all(db_path: Path, table: str) -> list[tuple]:
+def _select_all(db_path: Path, table: str) -> list[tuple[Any, ...]]:
     conn = sqlite3.connect(str(db_path))
     try:
         rows = conn.execute(f"SELECT * FROM {table} ORDER BY 1").fetchall()
@@ -108,6 +111,43 @@ class TestEvals:
         assert rows[0][2] == 1.8  # val_loss
 
 
+class TestTokenization:
+    def test_tokenization_written(self, tmp_path: Path) -> None:
+        rec = MetricsRecorder(tmp_path)
+        rec.record_run_start(RunStart(run_id=1, adapter_version=None, phase="sft", seed=0))
+        rec.record_tokenization(
+            TokenizationEvent(
+                run_id=1,
+                total_sections=4,
+                cache_hits=3,
+                cache_misses=1,
+                total_tokenize_seconds=0.25,
+                cache_bytes_after=1024,
+            )
+        )
+        rows = _select_all(metrics_db_path(tmp_path), "tokenization")
+        assert len(rows) == 1
+        assert rows[0][1:6] == (4, 3, 1, 0.25, 1024)
+
+
+class TestGateRecorder:
+    def test_gate_written(self, tmp_path: Path) -> None:
+        rec = MetricsRecorder(tmp_path)
+        rec.record_run_start(RunStart(run_id=1, adapter_version=None, phase="sft", seed=0))
+        rec.record_gate(
+            GateEvent(
+                run_id=1,
+                adapter_name="tone",
+                mean_weight=0.6,
+                sample_count=8,
+                mode="trained",
+            )
+        )
+        rows = _select_all(metrics_db_path(tmp_path), "gate_events")
+        assert len(rows) == 1
+        assert rows[0][1:5] == ("tone", 0.6, 8, "trained")
+
+
 class TestExports:
     def test_export_written(self, tmp_path: Path) -> None:
         rec = MetricsRecorder(tmp_path)
@@ -236,3 +276,15 @@ class TestTrainerCallbackCompatibility:
 
         with pytest.raises(AttributeError, match="not_a_callback"):
             _ = callback.not_a_callback
+
+
+class TestMaybeFloat:
+    def test_none_returns_none(self) -> None:
+        assert _maybe_float(None) is None
+
+    def test_numeric_values_parse(self) -> None:
+        assert _maybe_float(1.25) == 1.25
+        assert _maybe_float("2.5") == 2.5
+
+    def test_bad_value_returns_none(self) -> None:
+        assert _maybe_float("nope") is None
tests/unit/metrics/test_sinks.py (added)
@@ -0,0 +1,168 @@
+"""Optional observability sinks: TensorBoard + W&B."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from types import ModuleType
+from typing import Any
+
+import pytest
+
+from dlm.metrics.events import EvalEvent, StepEvent
+from dlm.metrics.sinks import (
+    TensorBoardSink,
+    WandbSink,
+    tensorboard_available,
+    wandb_available,
+)
+
+
+class _FakeWriter:
+    def __init__(self, *, log_dir: str) -> None:
+        self.log_dir = log_dir
+        self.scalars: list[tuple[str, float, int]] = []
+        self.flushed = False
+        self.closed = False
+
+    def add_scalar(self, name: str, value: float, step: int) -> None:
+        self.scalars.append((name, value, step))
+
+    def flush(self) -> None:
+        self.flushed = True
+
+    def close(self) -> None:
+        self.closed = True
+
+
+class _FakeRun:
+    def __init__(self) -> None:
+        self.logged: list[tuple[dict[str, float], int]] = []
+        self.finished = False
+
+    def log(self, payload: dict[str, float], *, step: int) -> None:
+        self.logged.append((payload, step))
+
+    def finish(self) -> None:
+        self.finished = True
+
+
+class TestAvailabilityProbes:
+    def test_tensorboard_available_true_when_spec_exists(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        monkeypatch.setattr(
+            "importlib.util.find_spec",
+            lambda name: object() if name == "tensorboard" else None,
+        )
+        assert tensorboard_available() is True
+
+    def test_tensorboard_available_false_when_spec_missing(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        monkeypatch.setattr("importlib.util.find_spec", lambda _name: None)
+        assert tensorboard_available() is False
+
+    def test_wandb_available_true_when_spec_exists(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        monkeypatch.setattr(
+            "importlib.util.find_spec",
+            lambda name: object() if name == "wandb" else None,
+        )
+        assert wandb_available() is True
+
+    def test_wandb_available_false_when_spec_missing(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        monkeypatch.setattr("importlib.util.find_spec", lambda _name: None)
+        assert wandb_available() is False
+
+
+class TestTensorBoardSink:
+    def test_constructor_raises_cleanly_when_tensorboard_missing(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        real_import = __import__
+
+        def _fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
+            if name == "torch.utils.tensorboard":
+                raise ImportError("missing tensorboard")
+            return real_import(name, *args, **kwargs)
+
+        monkeypatch.setattr("builtins.__import__", _fake_import)
+        with pytest.raises(RuntimeError, match="requires `tensorboard`"):
+            TensorBoardSink(tmp_path, run_id=7)
+
+    def test_records_step_eval_and_close(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        fake_module = ModuleType("torch.utils.tensorboard")
+        fake_module.SummaryWriter = _FakeWriter  # type: ignore[attr-defined]
+        monkeypatch.setitem(sys.modules, "torch.utils.tensorboard", fake_module)
+
+        sink = TensorBoardSink(tmp_path, run_id=7)
+        sink.record_step(StepEvent(run_id=7, step=10, loss=1.2, lr=0.01, grad_norm=0.5))
+        sink.record_eval(EvalEvent(run_id=7, step=10, val_loss=0.9, perplexity=2.0))
+        sink.close()
+
+        writer = sink._writer
+        assert sink.log_dir == tmp_path / "tensorboard" / "run_0007"
+        assert writer.scalars == [
+            ("train/loss", 1.2, 10),
+            ("train/lr", 0.01, 10),
+            ("train/grad_norm", 0.5, 10),
+            ("eval/val_loss", 0.9, 10),
+            ("eval/perplexity", 2.0, 10),
+        ]
+        assert writer.flushed is True
+        assert writer.closed is True
+
+
+class TestWandbSink:
+    def test_constructor_raises_cleanly_when_wandb_missing(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        real_import = __import__
+
+        def _fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
+            if name == "wandb":
+                raise ImportError("missing wandb")
+            return real_import(name, *args, **kwargs)
+
+        monkeypatch.setattr("builtins.__import__", _fake_import)
+        with pytest.raises(RuntimeError, match="requires `wandb`"):
+            WandbSink(tmp_path, run_id=9, project="dlm")
+
+    def test_records_payloads_and_close(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        fake_run = _FakeRun()
+        fake_wandb = ModuleType("wandb")
+        fake_wandb.init = lambda **_kwargs: fake_run  # type: ignore[attr-defined]
+        monkeypatch.setitem(sys.modules, "wandb", fake_wandb)
+
+        sink = WandbSink(tmp_path, run_id=9, project="dlm")
+        sink.record_step(StepEvent(run_id=9, step=3, loss=1.0, lr=0.02, grad_norm=0.4))
+        sink.record_eval(EvalEvent(run_id=9, step=3, val_loss=0.8, perplexity=1.5))
+        sink.record_step(StepEvent(run_id=9, step=4, loss=None, lr=None, grad_norm=None))
+        sink.record_eval(EvalEvent(run_id=9, step=4, val_loss=None, perplexity=None))
+        sink.close()
+
+        assert fake_run.logged == [
+            ({"train/loss": 1.0, "train/lr": 0.02, "train/grad_norm": 0.4}, 3),
+            ({"eval/val_loss": 0.8, "eval/perplexity": 1.5}, 3),
+        ]
+        assert fake_run.finished is True
tests/unit/modality/test_dispatch_modules.py (added)
@@ -0,0 +1,136 @@
+"""Direct coverage for modality dispatch wrapper modules."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+
+from dlm.base_models import BaseModelSpec
+from dlm.modality.audio import AudioLanguageModality
+from dlm.modality.errors import UnknownModalityError
+from dlm.modality.registry import TextModality, _unknown, modality_for
+from dlm.modality.text import TextModality as ReexportedTextModality
+from dlm.modality.vl import VisionLanguageModality
+
+
+def _minimal_text_spec(*, modality: str = "text") -> BaseModelSpec:
+    return BaseModelSpec.model_validate(
+        {
+            "key": "demo-1b",
+            "hf_id": "org/demo-1b",
+            "revision": "0123456789abcdef0123456789abcdef01234567",
+            "architecture": "DemoForCausalLM",
+            "params": 1_000_000_000,
+            "target_modules": ["q_proj", "v_proj"],
+            "template": "chatml",
+            "gguf_arch": "demo",
+            "tokenizer_pre": "demo",
+            "license_spdx": "Apache-2.0",
+            "license_url": None,
+            "requires_acceptance": False,
+            "redistributable": True,
+            "size_gb_fp16": 2.0,
+            "context_length": 4096,
+            "recommended_seq_len": 2048,
+            "modality": modality,
+        }
+    )
+
+
+def test_text_module_reexports_text_modality() -> None:
+    assert ReexportedTextModality is TextModality
+
+
+def test_text_dispatch_defaults_are_noops() -> None:
+    dispatch = TextModality()
+
+    assert dispatch.load_processor(_minimal_text_spec()) is None
+    assert (
+        dispatch.dispatch_export(
+            store=object(),
+            spec=_minimal_text_spec(),
+            adapter_name=None,
+            quant=None,
+            merged=False,
+            adapter_mix_raw=None,
+        )
+        is None
+    )
+
+
+def test_unknown_error_contains_registration_hint() -> None:
+    err = _unknown("mystery")
+    assert isinstance(err, UnknownModalityError)
+    assert "Register a ModalityDispatch subclass" in str(err)
+
+
+def test_modality_for_unknown_modality_raises() -> None:
+    with pytest.raises(UnknownModalityError, match="mystery"):
+        modality_for(SimpleNamespace(modality="mystery"))
+
+
+def test_audio_modality_loads_processor_and_dispatches_export() -> None:
+    dispatch = AudioLanguageModality()
+    spec = SimpleNamespace()
+
+    with (
+        patch("dlm.train.loader.load_processor", return_value="processor") as load_processor,
+        patch("dlm.export.dispatch.dispatch_audio_export", return_value="audio-export") as export,
+    ):
+        processor = dispatch.load_processor(spec)
+        result = dispatch.dispatch_export(
+            store="store",
+            spec=spec,
+            adapter_name="adapter",
+            quant="q4_k_m",
+            merged=False,
+            adapter_mix_raw="named",
+        )
+
+    assert processor == "processor"
+    load_processor.assert_called_once_with(spec)
+    assert result == "audio-export"
+    export.assert_called_once_with(
+        store="store",
+        spec=spec,
+        adapter_name="adapter",
+        quant="q4_k_m",
+        merged=False,
+        adapter_mix_raw="named",
+    )
+
+
+def test_vl_modality_loads_processor_and_dispatches_export() -> None:
+    dispatch = VisionLanguageModality()
+    spec = SimpleNamespace()
+    context = {"emit": "gguf"}
+
+    with (
+        patch("dlm.train.loader.load_processor", return_value="processor") as load_processor,
+        patch("dlm.export.dispatch.dispatch_vl_export", return_value="vl-export") as export,
+    ):
+        processor = dispatch.load_processor(spec)
+        result = dispatch.dispatch_export(
+            store="store",
+            spec=spec,
+            adapter_name="adapter",
+            quant="q8_0",
+            merged=True,
+            adapter_mix_raw=None,
+            gguf_emission_context=context,
+        )
+
+    assert processor == "processor"
+    load_processor.assert_called_once_with(spec)
+    assert result == "vl-export"
+    export.assert_called_once_with(
+        store="store",
+        spec=spec,
+        adapter_name="adapter",
+        quant="q8_0",
+        merged=True,
+        adapter_mix_raw=None,
+        gguf_emission_context=context,
+    )
tests/unit/modality/test_vl_contract.py (added)
@@ -0,0 +1,72 @@
+"""Direct coverage for VL runtime contract guardrails."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from dlm.modality.errors import ProcessorContractError
+from dlm.modality.vl_contract import ensure_supported_vl_runtime, validate_loaded_vl_processor
+
+
+def test_ensure_supported_vl_runtime_is_noop_for_non_vl_specs() -> None:
+    ensure_supported_vl_runtime(
+        SimpleNamespace(modality="text", architecture="Anything", key="demo")
+    )
+
+
+def test_ensure_supported_vl_runtime_is_noop_for_supported_vl_architecture() -> None:
+    ensure_supported_vl_runtime(
+        SimpleNamespace(
+            modality="vision-language", architecture="Qwen2VLForConditionalGeneration", key="demo"
+        )
+    )
+
+
+def test_ensure_supported_vl_runtime_rejects_internvl_family() -> None:
+    with pytest.raises(ProcessorContractError, match="InternVL-family VL model"):
+        ensure_supported_vl_runtime(
+            SimpleNamespace(
+                modality="vision-language", architecture="InternVLChatModel", key="internvl"
+            )
+        )
+
+
+def test_validate_loaded_vl_processor_is_noop_for_non_vl_specs() -> None:
+    processor = object()
+    assert (
+        validate_loaded_vl_processor(
+            SimpleNamespace(modality="text", architecture="Demo", key="demo"), processor
+        )
+        is processor
+    )
+
+
+def test_validate_loaded_vl_processor_accepts_processor_with_image_processor() -> None:
+    processor = SimpleNamespace(image_processor=object())
+    assert (
+        validate_loaded_vl_processor(
+            SimpleNamespace(modality="vision-language", architecture="Demo", key="demo"),
+            processor,
+        )
+        is processor
+    )
+
+
+def test_validate_loaded_vl_processor_delegates_internvl_refusal() -> None:
+    with pytest.raises(ProcessorContractError, match="InternVL-family VL model"):
+        validate_loaded_vl_processor(
+            SimpleNamespace(
+                modality="vision-language", architecture="InternVLChatModel", key="internvl"
+            ),
+            SimpleNamespace(),
+        )
+
+
+def test_validate_loaded_vl_processor_rejects_missing_image_processor() -> None:
+    with pytest.raises(ProcessorContractError, match="without an `image_processor` attribute"):
+        validate_loaded_vl_processor(
+            SimpleNamespace(modality="vision-language", architecture="Demo", key="demo"),
+            SimpleNamespace(),
+        )
tests/unit/preference/test_cli_judge.py (modified)
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 import subprocess
 from pathlib import Path
 from unittest.mock import patch
@@ -11,11 +12,13 @@ import pytest
 from dlm.preference import (
     CliJudge,
     HfRewardModelJudge,
+    InvalidJudgeSpecError,
     JudgeInvocationError,
     JudgeUnavailableError,
     SwayJudge,
     build_judge,
 )
+from dlm.preference.judge import _combine_reasoning, _parse_cli_candidate_score
 
 
 def _proc(
@@ -33,6 +36,21 @@
 
 
 class TestCliJudge:
+    def test_blank_command_is_rejected(self) -> None:
+        with pytest.raises(InvalidJudgeSpecError, match="include a command"):
+            CliJudge("   ")
+
+    def test_empty_argv_after_split_is_rejected(self) -> None:
+        with (
+            patch("dlm.preference.judge.shlex.split", return_value=[]),
+            pytest.raises(InvalidJudgeSpecError, match="include a command"),
+        ):
+            CliJudge("judge-bin")
+
+    def test_non_positive_timeout_is_rejected(self) -> None:
+        with pytest.raises(ValueError, match="timeout must be > 0"):
+            CliJudge("judge-bin", timeout=0.0)
+
     def test_scores_pair_via_two_json_round_trips(self) -> None:
         seen_payloads: list[str] = []
 
@@ -114,6 +132,54 @@ class TestCliJudge:
         ):
             judge.score_pair("p", "a", "b")
 
+    def test_oserror_raises_unavailable_error(self) -> None:
+        judge = CliJudge("judge-bin")
+        with (
+            patch(
+                "dlm.preference.judge.subprocess.run",
+                side_effect=OSError("permission denied"),
+            ),
+            pytest.raises(JudgeUnavailableError, match="could not start"),
+        ):
+            judge.score_pair("p", "a", "b")
+
+
+class TestCliJudgeHelpers:
+    def test_empty_stdout_is_rejected(self) -> None:
+        with pytest.raises(JudgeInvocationError, match="empty stdout"):
+            _parse_cli_candidate_score("   ")
+
+    def test_json_must_be_object(self) -> None:
+        with pytest.raises(JudgeInvocationError, match="JSON object"):
+            _parse_cli_candidate_score('["not", "an", "object"]')
+
+    @pytest.mark.parametrize("score", [float("nan"), float("inf"), -float("inf")])
+    def test_score_must_be_finite(self, score: float) -> None:
+        rendered = "NaN" if math.isnan(score) else ("Infinity" if score > 0 else "-Infinity")
+        with pytest.raises(JudgeInvocationError, match="must be finite"):
+            _parse_cli_candidate_score(f'{{"score": {rendered}}}')
+
+    def test_reasoning_must_be_string_when_present(self) -> None:
+        with pytest.raises(JudgeInvocationError, match="must be a string"):
+            _parse_cli_candidate_score('{"score": 1.0, "reasoning": 7}')
+
+    @pytest.mark.parametrize(
+        ("left", "right", "expected"),
+        [
+            ("why a", None, "a: why a"),
+            (None, "why b", "b: why b"),
+            ("why a", "why b", "a: why a | b: why b"),
+            (None, None, None),
+        ],
+    )
+    def test_combine_reasoning_formats_present_parts(
+        self,
+        left: str | None,
+        right: str | None,
+        expected: str | None,
+    ) -> None:
+        assert _combine_reasoning(left, right) == expected
+
 
 class TestBuildJudge:
     def test_cli_ref_builds_concrete_cli_judge(self) -> None:
tests/unit/preference/test_hf_reward_judge.py (modified)
@@ -2,11 +2,26 @@
 
 from __future__ import annotations
 
+import builtins
 from collections import deque
+from types import SimpleNamespace
+from unittest.mock import patch
 
 import pytest
 
-from dlm.preference import HfRewardModelJudge, JudgeInvocationError
+from dlm.preference import (
+    HfRewardModelJudge,
+    InvalidJudgeSpecError,
+    JudgeInvocationError,
+    JudgeUnavailableError,
+)
+from dlm.preference.judge import (
+    _default_reward_loader,
+    _encode_reward_input,
+    _extract_reward_scalar,
+    _move_to_device,
+    _resolve_reward_device,
+)
 
 
 class FakeScalar:
@@ -37,14 +52,35 @@
         return self
 
 
+class FakeTensor:
+    def __init__(self) -> None:
+        self.device: str | None = None
+
+    def to(self, device: str) -> FakeTensor:
+        self.device = device
+        return self
+
+
 class FakeTokenizer:
-    def __init__(self, *, use_chat_template: bool = False) -> None:
+    def __init__(
+        self,
+        *,
+        use_chat_template: bool = False,
+        template_error: Exception | None = None,
+        template_returns_non_string: bool = False,
+    ) -> None:
         self.calls: list[tuple[str, tuple[object, ...], dict[str, object]]] = []
+        self._template_error = template_error
+        self._template_returns_non_string = template_returns_non_string
         if use_chat_template:
             self.chat_template = "fake-template"
 
     def apply_chat_template(self, messages: list[dict[str, str]], **kwargs: object) -> str:
         self.calls.append(("apply_chat_template", (messages,), dict(kwargs)))
+        if self._template_error is not None:
+            raise self._template_error
+        if self._template_returns_non_string:
+            return ""  # type: ignore[return-value]
         return f"templated::{messages[0]['content']}::{messages[1]['content']}"
 
     def __call__(self, *args: object, **kwargs: object) -> FakeBatch:
@@ -67,6 +103,30 @@
         return Output(self._logits.popleft())
 
 
+class FakeTorchScalarLogits:
+    def __init__(self, value: float) -> None:
+        self._value = value
+
+    def numel(self) -> int:
+        return 1
+
+    def item(self) -> float:
+        return self._value
+
+
+class FakePretrainedRewardModel:
+    def __init__(self) -> None:
+        self.device: str | None = None
+        self.eval_called = False
+
+    def to(self, device: str) -> FakePretrainedRewardModel:
+        self.device = device
+        return self
+
+    def eval(self) -> None:
+        self.eval_called = True
+
+
 def _loader_factory(tokenizer: FakeTokenizer, model: FakeModel):
     calls: list[tuple[str, str]] = []
 
@@ -80,6 +140,10 @@
 
 
 class TestHfRewardModelJudge:
+    def test_blank_selector_is_rejected(self) -> None:
+        with pytest.raises(InvalidJudgeSpecError, match="include a model id"):
+            HfRewardModelJudge("   ")
+
     def test_scores_pair_and_caches_loaded_bundle(self) -> None:
         tokenizer = FakeTokenizer()
         model = FakeModel([FakeLogits([0.2]), FakeLogits([0.9])])
@@ -142,3 +206,162 @@
         with pytest.raises(JudgeInvocationError, match="no `.logits`"):
            judge.score_pair("prompt", "a", "b")
         assert calls == [("reward/model", "cpu")]
+
+    def test_missing_torch_is_reported(self) -> None:
+        tokenizer = FakeTokenizer()
+        model = FakeModel([FakeLogits([0.2]), FakeLogits([0.1])])
+        _, loader = _loader_factory(tokenizer, model)
+        judge = HfRewardModelJudge("reward/model", device="cpu", loader=loader)
+        real_import = builtins.__import__
+
+        def fake_import(name: str, *args: object, **kwargs: object):
+            if name == "torch":
+                raise ImportError("no torch here")
+            return real_import(name, *args, **kwargs)
+
+        with (
+            patch("builtins.__import__", side_effect=fake_import),
+            pytest.raises(JudgeUnavailableError, match="requires torch"),
+        ):
+            judge.score_pair("prompt", "a", "b")
+
+    def test_default_loader_path_is_used_when_no_loader_is_supplied(self) -> None:
+        tokenizer = FakeTokenizer()
+        model = FakeModel([FakeLogits([0.7]), FakeLogits([0.1])])
+
+        def fake_default_loader(hf_id: str, device: str):
+            from dlm.preference.judge import _LoadedRewardJudge
+
+            assert hf_id == "reward/model"
+            assert device == "cpu"
+            return _LoadedRewardJudge(model=model, tokenizer=tokenizer, device=device)
+
+        judge = HfRewardModelJudge("reward/model", device="cpu")
+        with patch("dlm.preference.judge._default_reward_loader", side_effect=fake_default_loader):
+            score = judge.score_pair("prompt", "a", "b")
+
+        assert score.preferred == "a"
+
+
+class TestHfRewardHelpers:
+    def test_default_reward_loader_requires_transformers(self) -> None:
+        real_import = builtins.__import__
+
+        def fake_import(name: str, *args: object, **kwargs: object):
+            if name == "transformers":
+                raise ImportError("missing transformers")
+            return real_import(name, *args, **kwargs)
+
+        with (
+            patch("builtins.__import__", side_effect=fake_import),
+            pytest.raises(JudgeUnavailableError, match="requires transformers"),
+        ):
+            _default_reward_loader("reward/model", "cpu")
+
+    def test_default_reward_loader_moves_model_and_sets_eval(self) -> None:
+        model = FakePretrainedRewardModel()
+        tokenizer = FakeTokenizer()
+
+        class AutoModelForSequenceClassification:
+            @staticmethod
+            def from_pretrained(hf_id: str) -> FakePretrainedRewardModel:
+                assert hf_id == "reward/model"
+                return model
+
+        class AutoTokenizer:
+            @staticmethod
+            def from_pretrained(hf_id: str) -> FakeTokenizer:
+                assert hf_id == "reward/model"
+                return tokenizer
+
+        fake_transformers = SimpleNamespace(
+            AutoModelForSequenceClassification=AutoModelForSequenceClassification,
+            AutoTokenizer=AutoTokenizer,
+        )
+
+        with patch.dict("sys.modules", {"transformers": fake_transformers}):
+            loaded = _default_reward_loader("reward/model", "mps")
+
+        assert loaded.model is model
+        assert loaded.tokenizer is tokenizer
+        assert loaded.device == "mps"
+        assert model.device == "mps"
+        assert model.eval_called is True
+
+    def test_resolve_reward_device_respects_explicit_request(self) -> None:
+        assert _resolve_reward_device("cuda:3") == "cuda:3"
+
+    def test_resolve_reward_device_returns_cpu_when_torch_is_missing(self) -> None:
+        real_import = builtins.__import__
+
+        def fake_import(name: str, *args: object, **kwargs: object):
+            if name == "torch":
+                raise ImportError("no torch")
+            return real_import(name, *args, **kwargs)
+
+        with patch("builtins.__import__", side_effect=fake_import):
+            assert _resolve_reward_device("auto") == "cpu"
+
+    def test_resolve_reward_device_prefers_cuda_then_mps_then_cpu(self) -> None:
+        torch_cuda = SimpleNamespace(
+            cuda=SimpleNamespace(is_available=lambda: True),
+            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
+        )
+        torch_mps = SimpleNamespace(
+            cuda=SimpleNamespace(is_available=lambda: False),
+            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
+        )
+        torch_cpu = SimpleNamespace(
+            cuda=SimpleNamespace(is_available=lambda: False),
+            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
+        )
+
+        with patch.dict("sys.modules", {"torch": torch_cuda}):
+            assert _resolve_reward_device("auto") == "cuda"
+        with patch.dict("sys.modules", {"torch": torch_mps}):
+            assert _resolve_reward_device("auto") == "mps"
+        with patch.dict("sys.modules", {"torch": torch_cpu}):
+            assert _resolve_reward_device("auto") == "cpu"
+
+    def test_encode_reward_input_falls_back_when_template_raises(self) -> None:
+        tokenizer = FakeTokenizer(use_chat_template=True, template_error=RuntimeError("boom"))
+
+        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")
+
+        assert isinstance(encoded, FakeBatch)
+        assert tokenizer.calls[-1][0] == "tokenizer"
+        assert tokenizer.calls[-1][1] == ("prompt",)
+        assert tokenizer.calls[-1][2]["text_pair"] == "candidate"
+
+    def test_encode_reward_input_falls_back_when_template_returns_non_string(self) -> None:
+        tokenizer = FakeTokenizer(use_chat_template=True, template_returns_non_string=True)
+
+        encoded = _encode_reward_input(tokenizer, "prompt", "candidate")
+
+        assert isinstance(encoded, FakeBatch)
+        assert tokenizer.calls[-1][0] == "tokenizer"
+
+    def test_move_to_device_moves_mapping_values(self) -> None:
+        tensor = FakeTensor()
+        payload = {"input_ids": tensor, "meta": "keep"}
+
+        moved = _move_to_device(payload, "mps")
+
+        assert moved["input_ids"] is tensor
+        assert tensor.device == "mps"
+        assert moved["meta"] == "keep"
+
+    def test_move_to_device_returns_unmodified_non_mapping_values(self) -> None:
+        value = object()
+        assert _move_to_device(value, "cpu") is value
+
+    def test_extract_reward_scalar_uses_item_fallback(self) -> None:
+        assert _extract_reward_scalar(FakeTorchScalarLogits(0.75)) == pytest.approx(0.75)
+
+    def test_extract_reward_scalar_rejects_unreadable_values(self) -> None:
+        class UnreadableLogits:
+            def numel(self) -> int:
+                return 1
+
+        with pytest.raises(JudgeInvocationError, match="unreadable scalar logit"):
+            _extract_reward_scalar(UnreadableLogits())
tests/unit/preference/test_mine_dedup.py (modified)
@@ -4,6 +4,8 @@ from __future__ import annotations

 from collections import deque

+import pytest
+
 from dlm.doc.parser import parse_text
 from dlm.preference import (
     PreferenceMineSkipReason,
@@ -11,6 +13,7 @@ from dlm.preference import (
     render_mine_plan,
 )
 from dlm.preference.judge import PairScore
+from dlm.preference.mine import _best_pair, _first_line, _resolve_pair, _unique_nonempty

 _FRONTMATTER = """---
 dlm_id: 01KPQ9X1000000000000000000
@@ -44,6 +47,20 @@ def _parsed(body: str):


 class TestBuildMinePlan:
+    def test_validates_numeric_limits(self) -> None:
+        parsed = _parsed("::instruction::\n### Q\nquestion?\n### A\nreference\n")
+        backend = StubBackend({"question?": ["one", "two"]})
+        judge = StubJudge({("question?", "one", "two"): PairScore(score_a=1.0, score_b=0.0)})
+
+        with pytest.raises(ValueError, match="samples must be >= 2"):
+            build_mine_plan(parsed, backend, judge, mined_run_id=1, samples=1)
+        with pytest.raises(ValueError, match="max_pairs must be >= 1"):
+            build_mine_plan(parsed, backend, judge, mined_run_id=1, samples=2, max_pairs=0)
+        with pytest.raises(ValueError, match="threshold must be >= 0.0"):
+            build_mine_plan(parsed, backend, judge, mined_run_id=1, samples=2, threshold=-0.1)
+        with pytest.raises(ValueError, match="max_new_tokens must be >= 1"):
+            build_mine_plan(parsed, backend, judge, mined_run_id=1, samples=2, max_new_tokens=0)
+
     def test_materializes_auto_mined_preference_section(self) -> None:
         parsed = _parsed("::instruction::\n### Q\nquestion?\n### A\nreference\n")
         backend = StubBackend({"question?": ["bad answer", "good answer"]})
@@ -192,3 +209,55 @@ class TestBuildMinePlan:
         assert plan.additions == ()
         assert len(plan.skipped) == 1
         assert plan.skipped[0].reason is PreferenceMineSkipReason.MALFORMED_INSTRUCTION
+
+    def test_stops_collecting_once_max_pairs_is_reached(self) -> None:
+        parsed = _parsed(
+            "::instruction::\n### Q\nquestion one?\n### A\nreference\n\n"
+            "::instruction::\n### Q\nquestion two?\n### A\nreference\n"
+        )
+        backend = StubBackend(
+            {
+                "question one?": ["bad one", "good one"],
+                "question two?": ["bad two", "good two"],
+            }
+        )
+        judge = StubJudge(
+            {
+                ("question one?", "bad one", "good one"): PairScore(score_a=0.1, score_b=0.9),
+                ("question two?", "bad two", "good two"): PairScore(score_a=0.1, score_b=0.9),
+            }
+        )
+
+        plan = build_mine_plan(parsed, backend, judge, mined_run_id=4, samples=2, max_pairs=1)
+
+        assert len(plan.additions) == 1
+        assert plan.additions[0].source.prompt == "question one?"
+
+    def test_insufficient_variety_is_reported(self) -> None:
+        parsed = _parsed("::instruction::\n### Q\nquestion?\n### A\nreference\n")
+        backend = StubBackend({"question?": [" same ", "", "same", "   "]})
+        judge = StubJudge({})
+
+        plan = build_mine_plan(parsed, backend, judge, mined_run_id=6, samples=4)
+
+        assert plan.additions == ()
+        assert len(plan.skipped) == 1
+        assert plan.skipped[0].reason is PreferenceMineSkipReason.INSUFFICIENT_VARIETY
+        assert "need at least 2 unique non-empty candidates" in plan.skipped[0].detail
+
+
+class TestMineHelpers:
+    def test_unique_nonempty_strips_blanks_and_duplicates(self) -> None:
+        assert _unique_nonempty(["", " alpha ", "alpha", "beta", "   "]) == ["alpha", "beta"]
+
+    def test_best_pair_skips_ties(self) -> None:
+        judge = StubJudge({("prompt", "a", "b"): PairScore(score_a=0.4, score_b=0.4)})
+
+        assert _best_pair("prompt", ["a", "b"], judge=judge) is None
+
+    def test_resolve_pair_returns_none_for_ties(self) -> None:
+        assert _resolve_pair("a", "b", PairScore(score_a=0.2, score_b=0.2)) is None
+
+    def test_first_line_truncates_long_text(self) -> None:
+        rendered = _first_line("x" * 90, max_chars=20)
+        assert rendered == ("x" * 19) + "…"
tests/unit/preference/test_pending.py (added, 199 lines changed)
@@ -0,0 +1,199 @@
+"""Tests for staged preference pending-plan helpers."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+
+from dlm.doc.sections import Section, SectionType
+from dlm.preference.pending import (
+    PendingPreferencePlanError,
+    _optional_float,
+    _optional_int,
+    _optional_str,
+    _section_from_payload,
+    clear_pending_plan,
+    load_pending_plan,
+    pending_plan_path,
+    save_pending_plan,
+)
+from dlm.store.paths import for_dlm
+
+_DLM_ID = "01KPQ9X1000000000000000000"
+
+
+def _mined_pref(
+    *,
+    prompt: str = "question?",
+    chosen: str = "better",
+    rejected: str = "worse",
+    run_id: int = 7,
+) -> Section:
+    body = f"### Prompt\n{prompt}\n### Chosen\n{chosen}\n### Rejected\n{rejected}"
+    return Section(
+        type=SectionType.PREFERENCE,
+        content=body,
+        start_line=12,
+        adapter="tone",
+        tags={"topic": "blas"},
+        auto_mined=True,
+        judge_name="sway:preference_judge",
+        judge_score_chosen=0.9,
+        judge_score_rejected=0.1,
+        mined_at="2026-04-23T20:00:00Z",
+        mined_run_id=run_id,
+    )
+
+
+def _image() -> Section:
+    return Section(
+        type=SectionType.IMAGE,
+        content="A DGEMM block diagram.",
+        media_path="diagram.png",
+        media_alt="DGEMM diagram",
+        media_blob_sha="ab" * 32,
+    )
+
+
+class TestPendingPlan:
+    def test_pending_path_round_trip_and_clear(self, tmp_path: Path) -> None:
+        home = tmp_path / "home"
+        source_path = tmp_path / "doc.dlm"
+        source_path.write_text("stub", encoding="utf-8")
+        store = for_dlm(_DLM_ID, home=home)
+
+        path = pending_plan_path(store)
+        assert path == home / "store" / _DLM_ID / "preference" / "pending.json"
+
+        saved = save_pending_plan(
+            store,
+            source_path=source_path,
+            sections=[_mined_pref(), _image()],
+        )
+        raw = json.loads(path.read_text(encoding="utf-8"))
+        loaded = load_pending_plan(store)
+
+        assert saved.source_path == source_path.resolve()
+        assert saved.created_at.endswith("Z")
+        assert raw["schema_version"] == 1
+        assert raw["source_path"] == str(source_path.resolve())
+        assert loaded == saved
+        assert clear_pending_plan(store) is True
+        assert clear_pending_plan(store) is False
+        assert load_pending_plan(store) is None
+
+    def test_load_returns_none_when_plan_absent(self, tmp_path: Path) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+
+        assert load_pending_plan(store) is None
+
+    def test_load_rejects_unreadable_plan(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("{}", encoding="utf-8")
+
+        def _raise(_self: Path, *, encoding: str) -> str:
+            _ = encoding
+            raise OSError("boom")
+
+        monkeypatch.setattr(Path, "read_text", _raise)
+        with pytest.raises(
+            PendingPreferencePlanError, match="could not read staged preference plan"
+        ):
+            load_pending_plan(store)
+
+    @pytest.mark.parametrize(
+        ("payload", "message"),
+        [
+            (["not", "an", "object"], "must be a JSON object"),
+            ({"schema_version": 2}, "unsupported staged preference plan schema_version=2"),
+            (
+                {"schema_version": 1, "created_at": "2026-04-24T20:00:00Z", "sections": []},
+                "missing source_path",
+            ),
+            (
+                {"schema_version": 1, "source_path": "/tmp/doc.dlm", "sections": []},
+                "missing created_at",
+            ),
+            (
+                {
+                    "schema_version": 1,
+                    "source_path": "/tmp/doc.dlm",
+                    "created_at": "2026-04-24T20:00:00Z",
+                },
+                "missing sections",
+            ),
+            (
+                {
+                    "schema_version": 1,
+                    "source_path": "/tmp/doc.dlm",
+                    "created_at": "2026-04-24T20:00:00Z",
+                    "sections": [{"content": "oops"}],
+                },
+                "invalid section payload at index 0",
+            ),
+        ],
+    )
+    def test_load_rejects_invalid_payloads(
+        self, tmp_path: Path, payload: object, message: str
+    ) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(json.dumps(payload), encoding="utf-8")
+
+        with pytest.raises(PendingPreferencePlanError, match=message):
+            load_pending_plan(store)
+
+    def test_load_rejects_invalid_json(self, tmp_path: Path) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("{not json", encoding="utf-8")
+
+        with pytest.raises(
+            PendingPreferencePlanError, match="staged preference plan is not valid JSON"
+        ):
+            load_pending_plan(store)
+
+
+class TestPendingPayloadHelpers:
+    def test_section_from_payload_validates_tags_and_optional_types(self) -> None:
+        with pytest.raises(TypeError, match="expected object, got list"):
+            _section_from_payload([])
+
+        with pytest.raises(TypeError, match="tags must be an object"):
+            _section_from_payload({"type": "preference", "content": "x", "tags": []})
+
+        with pytest.raises(TypeError, match="tags keys and values must be strings"):
+            _section_from_payload({"type": "preference", "content": "x", "tags": {"topic": 1}})
+
+        with pytest.raises(TypeError, match="expected float or null"):
+            _section_from_payload(
+                {"type": "preference", "content": "x", "judge_score_chosen": True}
+            )
+
+        with pytest.raises(TypeError, match="expected int or null"):
+            _section_from_payload({"type": "preference", "content": "x", "mined_run_id": True})
+
+    def test_optional_helpers_accept_none_and_reject_wrong_types(self) -> None:
+        assert _optional_str(None) is None
+        assert _optional_str("ok") == "ok"
+        assert _optional_float(None) is None
+        assert _optional_float(1) == 1.0
+        assert _optional_int(None) is None
+        assert _optional_int(7) == 7
+
+        with pytest.raises(TypeError, match="expected string or null"):
+            _optional_str(7)
+
+        with pytest.raises(TypeError, match="expected float or null"):
+            _optional_float(True)
+
+        with pytest.raises(TypeError, match="expected int or null"):
+            _optional_int(True)
tests/unit/preference/test_sway_bridge.py (added, 262 lines changed)
@@ -0,0 +1,262 @@
+"""Direct helper coverage for sway-backed preference judge wiring."""
+
+from __future__ import annotations
+
+import builtins
+import importlib
+import sys
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+
+from dlm.preference import JudgeUnavailableError
+from dlm.preference.judge import (
+    _build_sway_backend,
+    _import_sway_bridge,
+    _resolve_sway_trust_remote_code,
+)
+
+
+class FakeSwayError(Exception):
+    pass
+
+
+class FakeModelSpec:
+    def __init__(self, **kwargs: object) -> None:
+        self.kwargs = kwargs
+
+
+class FakeSysPath(list[str]):
+    def __init__(self) -> None:
+        super().__init__()
+        self.inserted: list[str] = []
+
+    def insert(self, index: int, value: str) -> None:  # type: ignore[override]
+        self.inserted.append(value)
+        super().insert(index, value)
+
+
+def test_build_sway_backend_requires_importable_bridge() -> None:
+    with (
+        patch("dlm.preference.judge._import_sway_bridge", side_effect=ImportError("missing")),
+        pytest.raises(JudgeUnavailableError, match="requires the sway bridge"),
+    ):
+        _build_sway_backend(Path("/tmp/example.dlm"))
+
+
+def test_build_sway_backend_wraps_sway_resolution_errors() -> None:
+    def resolve_dlm(_path: Path) -> object:
+        raise FakeSwayError("no store")
+
+    with (
+        patch(
+            "dlm.preference.judge._import_sway_bridge",
+            return_value=(resolve_dlm, object(), FakeModelSpec, FakeSwayError),
+        ),
+        pytest.raises(JudgeUnavailableError, match="could not resolve"),
+    ):
+        _build_sway_backend(Path("/tmp/example.dlm"))
+
+
+def test_build_sway_backend_wraps_generic_resolution_errors() -> None:
+    def resolve_dlm(_path: Path) -> object:
+        raise RuntimeError("boom")
+
+    with (
+        patch(
+            "dlm.preference.judge._import_sway_bridge",
+            return_value=(resolve_dlm, object(), FakeModelSpec, FakeSwayError),
+        ),
+        pytest.raises(JudgeUnavailableError, match="could not resolve"),
+    ):
+        _build_sway_backend(Path("/tmp/example.dlm"))
+
+
+def test_build_sway_backend_requires_trained_adapter() -> None:
+    handle = SimpleNamespace(adapter_path=None, base_model="base/model")
+
+    def resolve_dlm(_path: Path) -> object:
+        return handle
+
+    with (
+        patch(
+            "dlm.preference.judge._import_sway_bridge",
+            return_value=(resolve_dlm, object(), FakeModelSpec, FakeSwayError),
+        ),
+        pytest.raises(JudgeUnavailableError, match="requires a trained adapter"),
+    ):
+        _build_sway_backend(Path("/tmp/example.dlm"))
+
+
+def test_build_sway_backend_wraps_backend_load_errors() -> None:
+    handle = SimpleNamespace(adapter_path=Path("/tmp/adapter"), base_model="base/model")
+
+    def resolve_dlm(_path: Path) -> object:
+        return handle
+
+    def build_backend(_spec: FakeModelSpec, *, adapter_path: Path) -> object:
+        assert adapter_path == handle.adapter_path
+        raise RuntimeError("backend blew up")
+
+    with (
+        patch(
+            "dlm.preference.judge._import_sway_bridge",
+            return_value=(resolve_dlm, build_backend, FakeModelSpec, FakeSwayError),
+        ),
+        patch("dlm.preference.judge._resolve_sway_trust_remote_code", return_value=False),
+        pytest.raises(JudgeUnavailableError, match="could not load backend"),
+    ):
+        _build_sway_backend(Path("/tmp/example.dlm"))
+
+
+def test_build_sway_backend_builds_model_spec_with_trust_remote_code() -> None:
+    handle = SimpleNamespace(adapter_path=Path("/tmp/adapter"), base_model="base/model")
+    seen: dict[str, object] = {}
+
+    def resolve_dlm(_path: Path) -> object:
+        return handle
+
+    def build_backend(spec: FakeModelSpec, *, adapter_path: Path) -> object:
+        seen["spec"] = spec
+        seen["adapter_path"] = adapter_path
+        return "backend"
+
+    with (
+        patch(
+            "dlm.preference.judge._import_sway_bridge",
+            return_value=(resolve_dlm, build_backend, FakeModelSpec, FakeSwayError),
+        ),
+        patch("dlm.preference.judge._resolve_sway_trust_remote_code", return_value=True),
+    ):
+        backend = _build_sway_backend(Path("/tmp/example.dlm"))
+
+    assert backend == "backend"
+    spec = seen["spec"]
+    assert isinstance(spec, FakeModelSpec)
+    assert spec.kwargs == {
+        "kind": "hf",
+        "base": "base/model",
+        "adapter": handle.adapter_path,
+        "trust_remote_code": True,
+    }
+    assert seen["adapter_path"] == handle.adapter_path
+
+
+def test_import_sway_bridge_loads_modules_directly(monkeypatch: pytest.MonkeyPatch) -> None:
+    modules = {
+        "dlm_sway.backends": SimpleNamespace(build="build-backend"),
+        "dlm_sway.core.errors": SimpleNamespace(SwayError=FakeSwayError),
+        "dlm_sway.core.model": SimpleNamespace(ModelSpec=FakeModelSpec),
+        "dlm_sway.integrations.dlm.resolver": SimpleNamespace(resolve_dlm="resolve-dlm"),
+    }
+
+    def fake_import_module(name: str) -> object:
+        return modules[name]
+
+    monkeypatch.setattr(importlib, "import_module", fake_import_module)
+    resolve_dlm, build_backend, model_spec, sway_error = _import_sway_bridge()
+
+    assert resolve_dlm == "resolve-dlm"
+    assert build_backend == "build-backend"
+    assert model_spec is FakeModelSpec
+    assert sway_error is FakeSwayError
+
+
+def test_import_sway_bridge_falls_back_to_local_src_path(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    modules = {
+        "dlm_sway.backends": SimpleNamespace(build="build-backend"),
+        "dlm_sway.core.errors": SimpleNamespace(SwayError=FakeSwayError),
+        "dlm_sway.core.model": SimpleNamespace(ModelSpec=FakeModelSpec),
+        "dlm_sway.integrations.dlm.resolver": SimpleNamespace(resolve_dlm="resolve-dlm"),
+    }
+    calls = {"count": 0}
+
+    def fake_import_module(name: str) -> object:
+        calls["count"] += 1
+        if calls["count"] == 1:
+            raise ImportError("first import fails")
+        return modules[name]
+
+    fake_sys_path = FakeSysPath()
+
+    monkeypatch.setattr(importlib, "import_module", fake_import_module)
+    monkeypatch.setattr(Path, "exists", lambda self: True)
+    monkeypatch.setattr(sys, "path", fake_sys_path)
+    resolve_dlm, build_backend, model_spec, sway_error = _import_sway_bridge()
+
+    assert resolve_dlm == "resolve-dlm"
+    assert build_backend == "build-backend"
+    assert model_spec is FakeModelSpec
+    assert sway_error is FakeSwayError
+    assert fake_sys_path.inserted
+    assert fake_sys_path.inserted[0].endswith("/sway/src")
+
+
+def test_resolve_sway_trust_remote_code_returns_false_when_imports_are_missing() -> None:
+    real_import = builtins.__import__
+
+    def fake_import(name: str, *args: object, **kwargs: object):
+        if name in {"dlm.base_models", "dlm.doc.parser"}:
+            raise ImportError("missing")
+        return real_import(name, *args, **kwargs)
+
+    with patch("builtins.__import__", side_effect=fake_import):
+        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False
+
+
+def test_resolve_sway_trust_remote_code_handles_parse_and_resolve_failures() -> None:
+    fake_doc_parser = SimpleNamespace(
+        parse_file=lambda _path: (_ for _ in ()).throw(RuntimeError("bad"))
+    )
+    fake_base_models = SimpleNamespace(resolve=lambda *_args, **_kwargs: object())
+
+    with patch.dict(
+        "sys.modules",
+        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
+    ):
+        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False
+
+    parsed = SimpleNamespace(frontmatter=SimpleNamespace(base_model="custom-base"))
+    fake_doc_parser = SimpleNamespace(parse_file=lambda _path: parsed)
+    fake_base_models = SimpleNamespace(
+        resolve=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("no base"))
+    )
+
+    with patch.dict(
+        "sys.modules",
+        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
+    ):
+        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False
+
+
+@pytest.mark.parametrize("base_model", ["", "hf:org/model"])
+def test_resolve_sway_trust_remote_code_short_circuits_for_non_registry_models(
+    base_model: str,
+) -> None:
+    parsed = SimpleNamespace(frontmatter=SimpleNamespace(base_model=base_model))
+    fake_doc_parser = SimpleNamespace(parse_file=lambda _path: parsed)
+    fake_base_models = SimpleNamespace(resolve=lambda *_args, **_kwargs: object())
+
+    with patch.dict(
+        "sys.modules",
+        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
+    ):
+        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False
+
+
+def test_resolve_sway_trust_remote_code_returns_spec_flag() -> None:
+    parsed = SimpleNamespace(frontmatter=SimpleNamespace(base_model="qwen3-1.7b"))
+    fake_doc_parser = SimpleNamespace(parse_file=lambda _path: parsed)
+    fake_base_models = SimpleNamespace(
+        resolve=lambda *_args, **_kwargs: SimpleNamespace(trust_remote_code=True)
+    )

+    with patch.dict(
+        "sys.modules",
+        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
+    ):
+        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is True
tests/unit/repl/test_app_helpers.py (added, 10 lines changed)
@@ -0,0 +1,10 @@
+"""Direct coverage for small non-interactive REPL helpers."""
+
+from __future__ import annotations
+
+from dlm.repl.app import _format_prompt
+
+
+def test_format_prompt_handles_empty_and_existing_history() -> None:
+    assert _format_prompt([]) == "> "
+    assert _format_prompt([object(), object(), object(), object()]) == "[2] > "
tests/unit/repl/test_commands.py (modified, 19 lines changed)
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock

 import pytest

-from dlm.repl.commands import Action, is_command, parse_and_dispatch
+from dlm.repl.commands import Action, _truncate, is_command, parse_and_dispatch
 from dlm.repl.errors import BadCommandArgumentError, UnknownCommandError
 from dlm.repl.session import ReplSession

@@ -166,6 +166,11 @@ class TestHelp:
             assert cmd in result.message


+class TestHelpers:
+    def test_truncate_adds_ellipsis_for_long_lines(self) -> None:
+        assert _truncate("x" * 20, 10) == "x" * 9 + "…"
+
+
 class TestUnknownCommand:
     def test_unknown_slash_raises(self) -> None:
         with pytest.raises(UnknownCommandError, match="/bogus"):
tests/unit/repl/test_streaming.py (added, 64 lines changed)
@@ -0,0 +1,64 @@
+"""Direct unit coverage for REPL streaming helpers."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from dlm.repl.streaming import CaptureStreamer, build_streamer, concatenate_tokens, should_stream
+
+
+def test_should_stream_tracks_stdout_tty_state() -> None:
+    with patch("sys.stdout", new=SimpleNamespace(isatty=lambda: True)):
+        assert should_stream() is True
+    with patch("sys.stdout", new=SimpleNamespace(isatty=lambda: False)):
+        assert should_stream() is False
+
+
+def test_should_stream_handles_broken_stdout() -> None:
+    class MissingIsAtty:
+        pass
+
+    class RaisesValueError:
+        @staticmethod
+        def isatty() -> bool:
+            raise ValueError("closed")
+
+    with patch("sys.stdout", new=MissingIsAtty()):
+        assert should_stream() is False
+    with patch("sys.stdout", new=RaisesValueError()):
+        assert should_stream() is False
+
+
+def test_capture_streamer_is_noop_and_keeps_text_buffer() -> None:
+    streamer = CaptureStreamer()
+    streamer.put(["ignored"])
+    streamer.end()
+    assert streamer.text == ""
+
+
+def test_build_streamer_returns_capture_streamer_when_disabled() -> None:
+    assert isinstance(build_streamer(object(), stream_to_stdout=False), CaptureStreamer)
+
+
+def test_build_streamer_wraps_transformers_text_streamer() -> None:
+    calls: list[tuple[object, bool, bool]] = []
+
+    class FakeTextStreamer:
+        def __init__(
+            self, tokenizer: object, *, skip_prompt: bool, skip_special_tokens: bool
+        ) -> None:
+            calls.append((tokenizer, skip_prompt, skip_special_tokens))
+
+    fake_transformers = SimpleNamespace(TextStreamer=FakeTextStreamer)
+    tokenizer = object()
+
+    with patch.dict("sys.modules", {"transformers": fake_transformers}):
+        streamer = build_streamer(tokenizer, stream_to_stdout=True)
+
+    assert isinstance(streamer, FakeTextStreamer)
+    assert calls == [(tokenizer, True, True)]
+
+
+def test_concatenate_tokens_joins_token_pieces() -> None:
+    assert concatenate_tokens(["hello", " ", "world"]) == "hello world"
tests/unit/replay/test_corpus.py (modified, 27 lines changed)
@@ -8,6 +8,7 @@ from pathlib import Path

 import pytest

+import dlm.replay.corpus as replay_corpus
 from dlm.replay.corpus import _encode_frame, append_snapshot, iter_snapshots, read_chunk
 from dlm.replay.errors import CorpusCorruptError
 from dlm.replay.models import SectionSnapshot
@@ -89,6 +90,20 @@ class TestCorruption:
         with pytest.raises(CorpusCorruptError):
             read_chunk(corpus, byte_offset=0, length=len(frame))

+    def test_cbor_value_error_is_wrapped(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        frame = _encode_frame(_snap("a" * 16, "hello"))
+        corpus = tmp_path / "corpus.zst"
+        corpus.write_bytes(frame)
+
+        def _boom(_payload: bytes) -> object:
+            raise ValueError("bad semantic tag")
+
+        monkeypatch.setattr(replay_corpus.cbor2, "loads", _boom)
+        with pytest.raises(CorpusCorruptError, match="CBOR decode failed"):
+            read_chunk(corpus, byte_offset=0, length=len(frame))
+
     def test_iter_short_read_raises(self, tmp_path: Path) -> None:
         """iter_snapshots also guards against truncated entries."""
         from dlm.replay.models import IndexEntry
tests/unit/replay/test_sampler.py (modified, 36 lines changed)
@@ -9,7 +9,7 @@ import pytest

 from dlm.replay.errors import SamplerError
 from dlm.replay.models import IndexEntry
-from dlm.replay.sampler import sample
+from dlm.replay.sampler import _weighted_reservoir, sample

 _NOW = datetime(2026, 4, 18)

@@ -113,3 +113,28 @@ class TestStableOrdering:
         p1 = sample(entries_a, k=5, now=_NOW, rng=random.Random(7), scheme="uniform")
         p2 = sample(entries_b, k=5, now=_NOW, rng=random.Random(7), scheme="uniform")
         assert [e.section_id for e in p1] == [e.section_id for e in p2]
+
+
+class TestReservoirEdgeCases:
+    def test_zero_random_draw_retries_and_falls_back_to_tiny_positive(self) -> None:
+        entries = _entries(2)
+
+        class _ZeroThenHalfRng:
+            def __init__(self) -> None:
+                self._values = iter([0.0, 0.0, 0.5, 0.5])
+
+            def random(self) -> float:
+                return next(self._values)
+
+        picked = sample(entries, k=1, now=_NOW, rng=_ZeroThenHalfRng(), scheme="uniform")
+        assert len(picked) == 1
+
+    def test_nonpositive_weight_entries_are_skipped(self) -> None:
+        entries = _entries(3)
+        picked = _weighted_reservoir(
+            entries,
+            weights=[1.0, 0.0, -1.0],
+            k=3,
+            rng=random.Random(0),
+        )
+        assert [entry.section_id for entry in picked] == [entries[0].section_id]
tests/unit/share/test_hf_sink.py (modified, 27 lines changed)
@@ -152,6 +152,27 @@ class TestPushHf:
         with pytest.raises(SinkError, match="upload failed"):
             push_hf(pack, "user/myadapter")

+    def test_readme_upload_failure_translates_to_sink_error(
+        self,
+        pack: Path,
+        patched_hub: dict[str, list[dict[str, object]]],
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import huggingface_hub
+
+        calls = {"count": 0}
+
+        def _boom_on_second_upload(**kwargs: object) -> str:
+            calls["count"] += 1
+            if calls["count"] == 2:
+                raise _FakeHfHubHTTPError("readme denied")
+            return f"https://huggingface.co/{kwargs['repo_id']}/blob/main/{kwargs['path_in_repo']}"
+
+        monkeypatch.setattr(huggingface_hub, "upload_file", _boom_on_second_upload, raising=False)
+
+        with pytest.raises(SinkError, match="README upload failed"):
+            push_hf(pack, "user/myadapter")
+
     def test_progress_fires_with_full_size(
         self, pack: Path, patched_hub: dict[str, list[dict[str, object]]]
     ) -> None:
tests/unit/share/test_peer_runtime.py (added, 388 lines changed)
@@ -0,0 +1,388 @@
+"""Runtime coverage for the peer share transport."""
+
+from __future__ import annotations
+
+import importlib
+import socket
+from io import BytesIO
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+from dlm.share.errors import PeerAuthError, RateLimitError
+from dlm.share.peer import (
+    RateLimiter,
+    ServeHandle,
+    ServeOptions,
+    _detect_lan_ip,
+    _log_connection,
+    build_handler,
+    new_session,
+    pull_peer,
+    serve,
+)
+
+peer_mod = importlib.import_module("dlm.share.peer")
+
+
+def _build_test_handler(
+    tmp_path: Path,
+    *,
+    path: str,
+) -> tuple[type[object], object, list[tuple[str, str, str, str]], RateLimiter, Path]:
+    session = new_session("01HZPEER")
+    pack_path = tmp_path / "bundle.dlm.pack"
+    pack_path.write_bytes(b"peer-pack")
+    rate_limiter = RateLimiter(max_concurrency=4, rate_limit_per_min=30)
+    logs: list[tuple[str, str, str, str]] = []
+
+    handler_cls = build_handler(session, pack_path, rate_limiter)
+    handler = object.__new__(handler_cls)
+    handler.path = path
+    handler.client_address = ("127.0.0.1", 7337)
+    handler.send_error = lambda code, message: logs.append(("error", str(code), message, ""))  # type: ignore[attr-defined]
+    handler._stream_pack = lambda path: logs.append(("stream", str(path), "", ""))  # type: ignore[attr-defined]
+    return handler_cls, handler, logs, rate_limiter, pack_path
+
+
+class TestPeerHandler:
+    def test_log_message_is_silent(self, tmp_path: Path) -> None:
+        handler_cls, handler, _logs, _rate_limiter, _pack_path = _build_test_handler(
+            tmp_path, path="/ignored"
+        )
+        assert handler_cls.log_message(handler, "%s", "ignored") is None
+
+    def test_handler_rejects_unknown_dlm_id(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
+            tmp_path, path="/wrong?token=abc"
+        )
+        request_logs: list[tuple[str, str, str, str]] = []
+        monkeypatch.setattr(
+            peer_mod,
+            "_log_connection",
+            lambda ip, method, path, status: request_logs.append((ip, method, path, status)),
+        )
+
+        handler_cls.do_GET(handler)
+
+        assert events == [("error", "404", "unknown dlm_id", "")]
+        assert request_logs == [
+            ("127.0.0.1", "GET", "/wrong", "start"),
+            ("127.0.0.1", "GET", "/wrong", "404 unknown dlm_id"),
+        ]
+
+    def test_handler_rejects_missing_token(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
+            tmp_path, path="/01HZPEER"
+        )
+        request_logs: list[tuple[str, str, str, str]] = []
+        monkeypatch.setattr(
+            peer_mod,
+            "_log_connection",
+            lambda ip, method, path, status: request_logs.append((ip, method, path, status)),
+        )
+
+        handler_cls.do_GET(handler)
+
+        assert events == [("error", "401", "missing token", "")]
+        assert request_logs == [
+            ("127.0.0.1", "GET", "/01HZPEER", "start"),
+            ("127.0.0.1", "GET", "/01HZPEER", "401 missing token"),
+        ]
+
+    def test_handler_rejects_bad_token(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
+            tmp_path, path="/01HZPEER?token=bad"
+        )
+        request_logs: list[tuple[str, str, str, str]] = []
+        monkeypatch.setattr(
+            peer_mod,
+            "_log_connection",
+            lambda ip, method, path, status: request_logs.append((ip, method, path, status)),
+        )
+        monkeypatch.setattr(
+            peer_mod.PeerSession,
+            "verify_token",
+            lambda self, token: (_ for _ in ()).throw(PeerAuthError("bad token")),
+        )
+
+        handler_cls.do_GET(handler)
+
+        assert events == [("error", "403", "token rejected", "")]
+        assert request_logs == [
+            ("127.0.0.1", "GET", "/01HZPEER", "start"),
+            ("127.0.0.1", "GET", "/01HZPEER", "403 bad token"),
+        ]
+
+    def test_handler_rejects_rate_limited(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        handler_cls, handler, events, rate_limiter, _pack_path = _build_test_handler(
+            tmp_path, path="/01HZPEER?token=good"
+        )
+        request_logs: list[tuple[str, str, str, str]] = []
+        monkeypatch.setattr(
+            peer_mod,
+            "_log_connection",
+            lambda ip, method, path, status: request_logs.append((ip, method, path, status)),
+        )
+        monkeypatch.setattr(peer_mod.PeerSession, "verify_token", lambda self, token: None)
+        monkeypatch.setattr(
+            rate_limiter,
+            "check_and_acquire",
+            lambda: (_ for _ in ()).throw(RateLimitError("too many")),
+        )
+
+        handler_cls.do_GET(handler)
+
+        assert events == [("error", "429", "rate limited", "")]
+        assert request_logs == [
+            ("127.0.0.1", "GET", "/01HZPEER", "start"),
+            ("127.0.0.1", "GET", "/01HZPEER", "429 too many"),
+        ]
+
+    def test_handler_streams_pack_and_releases_limiter(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        handler_cls, handler, events, rate_limiter, pack_path = _build_test_handler(
+            tmp_path, path="/01HZPEER?token=good"
+        )
+        request_logs: list[tuple[str, str, str, str]] = []
+        monkeypatch.setattr(
+            peer_mod,
+            "_log_connection",
+            lambda ip, method, path, status: request_logs.append((ip, method, path, status)),
+        )
+        monkeypatch.setattr(peer_mod.PeerSession, "verify_token", lambda self, token: None)
+
+        handler_cls.do_GET(handler)
+
+        assert events == [("stream", str(pack_path), "", "")]
+        assert rate_limiter.active == 0
+        assert request_logs == [
+            ("127.0.0.1", "GET", "/01HZPEER", "start"),
+            ("127.0.0.1", "GET", "/01HZPEER", "200 complete"),
+        ]
+
+    def test_stream_pack_writes_headers_and_body(self, tmp_path: Path) -> None:
+        handler_cls, handler, _events, _rate_limiter, pack_path = _build_test_handler(
+            tmp_path, path="/ignored"
+        )
+        responses: list[tuple[str, str]] = []
+        body = BytesIO()
+        handler.wfile = body
+        handler.send_response = lambda status: responses.append(("status", str(status)))  # type: ignore[attr-defined]
+        handler.send_header = lambda name, value: responses.append((name, value))  # type: ignore[attr-defined]
+        handler.end_headers = lambda: responses.append(("end", ""))  # type: ignore[attr-defined]
+
+        handler_cls._stream_pack(handler, pack_path)
+
+        assert responses == [
+            ("status", "200"),
+            ("Content-Type", "application/octet-stream"),
+            ("Content-Length", str(len(b"peer-pack"))),
+            ("end", ""),
+        ]
+        assert body.getvalue() == b"peer-pack"
+
+
+class TestPeerHelpers:
+    def test_log_connection_emits_metadata_only(self, caplog: pytest.LogCaptureFixture) -> None:
+        caplog.set_level("INFO")
+
+        _log_connection("127.0.0.1", "GET", "/01HZPEER", "200 complete")
+
+        assert "peer: GET /01HZPEER 200 complete from 127.0.0.1" in caplog.text
+
+    def test_pull_peer_reuses_url_sink(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.url_sink as url_sink
+
+        out_path = tmp_path / "incoming.dlm.pack"
+        seen: dict[str, object] = {}
+
+        def _fake_pull_url(url: str, actual_out: Path, *, progress: object | None = None) -> int:
+            seen["url"] = url
+            seen["out"] = actual_out
+            seen["progress"] = progress
+            return 42
+
+        monkeypatch.setattr(url_sink, "pull_url", _fake_pull_url)
+
+        result = pull_peer("host:7337/01HZPEER?token=abc", out_path, progress=None)
+
+        assert result == 42
+        assert seen == {
+            "url": "http://host:7337/01HZPEER?token=abc",
+            "out": out_path,
+            "progress": None,
+        }
+
+
+class TestServeHandle:
+    def test_peer_url_uses_bind_host_for_loopback(self) -> None:
+        handle = ServeHandle(
+            session=SimpleNamespace(dlm_id="01HZPEER"),
+            bind_host="127.0.0.1",
+            port=7337,
+            token="abc",
+            _server=SimpleNamespace(),
+        )
+
+        assert handle.peer_url == "peer://127.0.0.1:7337/01HZPEER?token=abc"
+
+    def test_peer_url_detects_lan_ip_for_public_bind(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        handle = ServeHandle(
+            session=SimpleNamespace(dlm_id="01HZPEER"),
+            bind_host="0.0.0.0",
+            port=7337,
+            token="abc",
+            _server=SimpleNamespace(),
+        )
+        monkeypatch.setattr(peer_mod, "_detect_lan_ip", lambda: "192.168.1.9")
+
+        assert handle.peer_url == "peer://192.168.1.9:7337/01HZPEER?token=abc"
+
+    def test_wait_shutdown_stops_server_cleanly(self) -> None:
+        calls: list[str] = []
+        server = SimpleNamespace(
+            serve_forever=lambda: calls.append("serve_forever"),
+            shutdown=lambda: calls.append("shutdown"),
+            server_close=lambda: calls.append("server_close"),
+        )
+        handle = ServeHandle(
+            session=SimpleNamespace(dlm_id="01HZPEER"),
+            bind_host="127.0.0.1",
+            port=7337,
+            token="abc",
+            _server=server,
+        )
+
+        handle.wait_shutdown()
+
+        assert calls == ["serve_forever", "shutdown", "server_close"]
+
+    def test_wait_shutdown_handles_keyboard_interrupt(
+        self, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        calls: list[str] = []
+
+        def _serve_forever() -> None:
+            calls.append("serve_forever")
+            raise KeyboardInterrupt
+
+        server = SimpleNamespace(
+            serve_forever=_serve_forever,
+            shutdown=lambda: calls.append("shutdown"),
+            server_close=lambda: calls.append("server_close"),
+        )
+        handle = ServeHandle(
+            session=SimpleNamespace(dlm_id="01HZPEER"),
+            bind_host="127.0.0.1",
+            port=7337,
+            token="abc",
+            _server=server,
+        )
+        caplog.set_level("INFO")
+
+        handle.wait_shutdown()
+
+        assert calls == ["serve_forever", "shutdown", "server_close"]
+        assert "shutdown requested" in caplog.text
+
+
+class TestServe:
+    def test_serve_builds_handle(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"peer-pack")
+        handler_cls = type("FakeHandler", (), {})
+        server_calls: dict[str, object] = {}
+
+        class FakeSession:
+            dlm_id = "01HZPEER"
+
+            def issue_token(self) -> str:
+                return "issued-token"
+
+        class FakeServer:
+            def __init__(self, address: tuple[str, int], handler: type[object]) -> None:
+                server_calls["address"] = address
+                server_calls["handler"] = handler
+
+        monkeypatch.setattr(
+            peer_mod, "new_session", lambda dlm_id, token_ttl_seconds: FakeSession()
+        )
+        monkeypatch.setattr(
+            peer_mod, "build_handler", lambda session, actual_pack, limiter: handler_cls
+        )
+        monkeypatch.setattr(peer_mod, "resolve_bind", lambda opts: "127.0.0.1")
+        monkeypatch.setattr(peer_mod.http.server, "ThreadingHTTPServer", FakeServer)
+
+        handle = serve("01HZPEER", pack_path, ServeOptions(port=8123))
+
+        assert handle.session.dlm_id == "01HZPEER"
+        assert handle.bind_host == "127.0.0.1"
+        assert handle.port == 8123
+        assert handle.token == "issued-token"
+        assert server_calls == {
+            "address": ("127.0.0.1", 8123),
+            "handler": handler_cls,
+        }
+
+
+class TestDetectLanIp:
+    def test_detect_lan_ip_returns_socket_address(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        class FakeSocket:
+            def settimeout(self, value: float) -> None:
+                assert value == 0.1
+
+            def connect(self, target: tuple[str, int]) -> None:
+                assert target == ("10.254.254.254", 1)
+
+            def getsockname(self) -> tuple[str, int]:
+                return ("192.168.1.7", 9999)
+
+            def __enter__(self) -> FakeSocket:
+                return self
+
+            def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
+                return None
+
+        monkeypatch.setattr(socket, "socket", lambda *args, **kwargs: FakeSocket())
+
+        assert _detect_lan_ip() == "192.168.1.7"
+
+    def test_detect_lan_ip_returns_placeholder_on_error(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        class FakeSocket:
+            def __enter__(self) -> FakeSocket:
+                raise OSError("no route")
+
+            def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
+                return None
+
+        monkeypatch.setattr(socket, "socket", lambda *args, **kwargs: FakeSocket())
+
+        assert _detect_lan_ip() == "<lan-ip>"
tests/unit/share/test_peer_tokens.py (modified)
59 lines changed
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import base64
 import time
 
 import pytest
@@ -46,6 +47,36 @@ class TestTokenRoundTrip:
         with pytest.raises(PeerAuthError):
             s.verify_token("AAAA")
 
+    def test_trailing_bytes_refused(self) -> None:
+        s = new_session("01HZTEST")
+        nonce = b"x" * 12
+        expiry_iso = "2099-01-01T00:00:00+00:00"
+        signature = s._sign(s.dlm_id, expiry_iso, nonce)
+        payload = (
+            nonce
+            + len(expiry_iso).to_bytes(2, "big")
+            + expiry_iso.encode("ascii")
+            + signature
+            + b"!"
+        )
+        token = base64.urlsafe_b64encode(payload).decode("ascii").rstrip("=")
+
+        with pytest.raises(PeerAuthError, match="trailing bytes"):
+            s.verify_token(token)
+
+    def test_malformed_expiry_refused(self) -> None:
+        s = new_session("01HZTEST")
+        nonce = b"y" * 12
+        expiry_iso = "not-a-date"
+        signature = s._sign(s.dlm_id, expiry_iso, nonce)
+        payload = (
+            nonce + len(expiry_iso).to_bytes(2, "big") + expiry_iso.encode("ascii") + signature
+        )
+        token = base64.urlsafe_b64encode(payload).decode("ascii").rstrip("=")
+
+        with pytest.raises(PeerAuthError, match="malformed expiry"):
+            s.verify_token(token)
+
     def test_expired_token(self) -> None:
         # TTL of 0 — any read after issuance is past-expiry.
         s = new_session("01HZTEST", token_ttl_seconds=0)
@@ -93,6 +124,16 @@ class TestRateLimiter:
         with pytest.raises(RateLimitError, match="req/min"):
             rl.check_and_acquire()
 
+    def test_prunes_requests_older_than_one_minute(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        rl = RateLimiter(max_concurrency=10, rate_limit_per_min=10)
+        rl.requests.extend([10.0, 190.0])
+        monkeypatch.setattr("dlm.share.peer.time.monotonic", lambda: 200.0)
+
+        rl.check_and_acquire()
+
+        assert list(rl.requests) == [190.0, 200.0]
+        assert rl.active == 1
+
     def test_release_idempotent_on_zero(self) -> None:
         # Release more than was acquired — shouldn't go negative.
         rl = RateLimiter()
tests/unit/share/test_provenance.py (modified)
84 lines changed
@@ -11,6 +11,7 @@ import pytest
 from dlm.share.provenance import (
     Provenance,
     ProvenanceChainBroken,
+    ProvenanceError,
     ProvenanceSchemaError,
     ProvenanceVerifyResult,
     UnknownSignerError,
@@ -178,6 +179,22 @@ class TestTrustedKeyRegistry:
         second = record_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path)
         assert first == second
 
+    def test_record_refuses_to_overwrite_different_key_contents(self, tmp_path: Path) -> None:
+        target = record_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path, label="alice")
+
+        with pytest.MonkeyPatch.context() as mp:
+            mp.setattr(
+                "dlm.share.provenance.pubkey_fingerprint",
+                lambda _key: target.stem.removeprefix("alice-"),
+            )
+            with pytest.raises(ProvenanceError, match="refusing to overwrite"):
+                record_trusted_key(
+                    _SAMPLE_PUBKEY + "\nDIFFERENT",
+                    trusted_keys_dir=tmp_path,
+                    label="alice",
+                )
+        assert target.is_file()
+
     def test_find_matching_returns_path(self, tmp_path: Path) -> None:
         record_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path)
         found = find_matching_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path)
@@ -193,6 +210,26 @@ class TestTrustedKeyRegistry:
         )
         assert found is None
 
+    def test_find_matching_skips_unreadable_pubkey_files(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        good = record_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path)
+        bad = tmp_path / "000-bad.pub"
+        bad.write_text("broken", encoding="utf-8")
+        path_type = type(bad)
+        real_read_text = path_type.read_text
+
+        def _maybe_broken(self: Path, *args: object, **kwargs: object) -> str:
+            if self == bad:
+                raise OSError("boom")
+            return real_read_text(self, *args, **kwargs)
+
+        monkeypatch.setattr(path_type, "read_text", _maybe_broken)
+
+        assert find_matching_trusted_key(_SAMPLE_PUBKEY, trusted_keys_dir=tmp_path) == good
+
 
 class TestVerifyProvenance:
     def _stub_verifier_accepts(self, chain: bytes, signature: str, pubkey_path: Path) -> None:
@@ -268,3 +305,29 @@
         prov = _sample_provenance(adapter_sha256="a" * 64)
         with pytest.raises(ProvenanceChainBroken, match="mismatch"):
             recompute_chain_consistency(prov, adapter_sha256="b" * 64)
+
+
+class TestDefaultSignatureVerifier:
+    def test_default_signature_verifier_writes_temp_files_and_calls_minisign(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+        tmp_path: Path,
+    ) -> None:
+        from dlm.share.provenance import _default_signature_verifier
+
+        seen: dict[str, object] = {}
+
+        def _fake_minisign_verify(payload: Path, sig: Path, pubkey: Path) -> None:
+            seen["payload"] = payload.read_bytes()
+            seen["signature"] = sig.read_text(encoding="utf-8")
+            seen["pubkey"] = pubkey
+
+        monkeypatch.setattr("dlm.share.signing._minisign_verify", _fake_minisign_verify)
+
+        pubkey = tmp_path / "key.pub"
+        pubkey.write_text("pub", encoding="utf-8")
+        _default_signature_verifier(b"chain-bytes", "signature-block", pubkey)
+
+        assert seen["payload"] == b"chain-bytes"
+        assert seen["signature"] == "signature-block"
+        assert seen["pubkey"] == pubkey
tests/unit/share/test_pull.py (added)
393 lines changed
@@ -0,0 +1,393 @@
+"""Unit coverage for the share pull orchestrator."""
+
+from __future__ import annotations
+
+import importlib
+import sys
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+from typing import cast
+
+import pytest
+
+from dlm.share.errors import ShareError, SinkError
+from dlm.share.pull import (
+    PullResult,
+    _dispatch_pull,
+    _log_verification,
+    _try_hf_sidecar,
+    _try_peer_sidecar,
+    _try_url_sidecar,
+    pull,
+)
+from dlm.share.signing import VerifyResult, VerifyStatus
+from dlm.share.sinks import SinkKind, SinkSpec
+
+pull_mod = importlib.import_module("dlm.share.pull")
+
+
+class TestPull:
+    def test_pull_dispatches_verifies_and_unpacks(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        source = "https://example.test/adapter.dlm.pack"
+        out_dir = tmp_path / "out"
+        home = tmp_path / "home"
+        progress = object()
+        spec = SinkSpec(kind=SinkKind.URL, target=source)
+        order: list[str] = []
+        verification = VerifyResult(status=VerifyStatus.VERIFIED, key_path=tmp_path / "trusted.pub")
+
+        monkeypatch.setattr(
+            pull_mod, "parse_source", lambda value: spec if value == source else None
+        )
+
+        def _fake_dispatch(
+            actual_spec: SinkSpec,
+            pack_path: Path,
+            sig_path: Path,
+            *,
+            progress: object | None,
+        ) -> int:
+            order.append("dispatch")
+            assert actual_spec == spec
+            assert pack_path.name == "incoming.dlm.pack"
+            assert sig_path.name == "incoming.dlm.pack.minisig"
+            assert progress is not None
+            pack_path.write_bytes(b"pack-bytes")
+            sig_path.write_text("signature", encoding="utf-8")
+            return 123
+
+        def _fake_verify(pack_path: Path, sig_path: Path) -> VerifyResult:
+            order.append("verify")
+            assert pack_path.read_bytes() == b"pack-bytes"
+            assert sig_path.read_text(encoding="utf-8") == "signature"
+            return verification
+
+        def _fake_unpack(
+            pack_path: Path,
+            *,
+            home: Path | None,
+            force: bool,
+            out_dir: Path,
+        ) -> SimpleNamespace:
+            order.append("unpack")
+            assert pack_path.read_bytes() == b"pack-bytes"
+            assert home == tmp_path / "home"
+            assert force is True
+            assert out_dir == tmp_path / "out"
+            return SimpleNamespace(
+                dlm_path=out_dir / "restored.dlm",
+                store_path=home / "store" / "01HZPULL",
+                dlm_id="01HZPULL",
+            )
+
+        monkeypatch.setattr(pull_mod, "_dispatch_pull", _fake_dispatch)
+        monkeypatch.setattr(pull_mod, "verify_signature", _fake_verify)
+        monkeypatch.setattr(pull_mod, "pack_unpack", _fake_unpack)
+
+        result = pull(
+            source,
+            out_dir=out_dir,
+            force=True,
+            home=home,
+            progress=cast("object", progress),
+        )
+
+        assert result == PullResult(
+            dlm_path=out_dir / "restored.dlm",
+            store_path=home / "store" / "01HZPULL",
+            dlm_id="01HZPULL",
+            source=source,
+            bytes_received=123,
+            verification=verification,
+        )
+        assert order == ["dispatch", "verify", "unpack"]
+
+
+class TestDispatchPull:
+    def test_dispatch_pull_hf_downloads_pack_and_sidecar(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.hf_sink as hf_sink
+
+        pack_path = tmp_path / "pack.dlm.pack"
+        sig_path = tmp_path / "pack.dlm.pack.minisig"
+        progress = object()
+        seen: dict[str, object] = {}
+
+        def _fake_pull_hf(repo_id: str, out_path: Path, *, progress: object | None = None) -> int:
+            seen["repo_id"] = repo_id
+            seen["progress"] = progress
+            out_path.write_bytes(b"hf-pack")
+            return 7
+
+        monkeypatch.setattr(hf_sink, "pull_hf", _fake_pull_hf)
+        monkeypatch.setattr(
+            pull_mod,
+            "_try_hf_sidecar",
+            lambda repo_id, sidecar_path: seen.update(
+                {"sidecar_repo_id": repo_id, "sidecar_path": sidecar_path}
+            ),
+        )
+
+        bytes_received = _dispatch_pull(
+            SinkSpec(kind=SinkKind.HF, target="org/repo"),
+            pack_path,
+            sig_path,
+            progress=cast("object", progress),
+        )
+
+        assert bytes_received == 7
+        assert pack_path.read_bytes() == b"hf-pack"
+        assert seen == {
+            "repo_id": "org/repo",
+            "progress": progress,
+            "sidecar_repo_id": "org/repo",
+            "sidecar_path": sig_path,
+        }
+
+    def test_dispatch_pull_url_downloads_pack_and_sidecar(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.url_sink as url_sink
+
+        pack_path = tmp_path / "pack.dlm.pack"
+        sig_path = tmp_path / "pack.dlm.pack.minisig"
+        seen: dict[str, object] = {}
+
+        def _fake_pull_url(url: str, out_path: Path, *, progress: object | None = None) -> int:
+            seen["url"] = url
+            seen["progress"] = progress
+            out_path.write_bytes(b"url-pack")
+            return 9
+
+        monkeypatch.setattr(url_sink, "pull_url", _fake_pull_url)
+        monkeypatch.setattr(
+            pull_mod,
+            "_try_url_sidecar",
+            lambda url, sidecar_path: seen.update(
+                {"sidecar_url": url, "sidecar_path": sidecar_path}
+            ),
+        )
+
+        bytes_received = _dispatch_pull(
+            SinkSpec(kind=SinkKind.URL, target="https://example.test/pack"),
+            pack_path,
+            sig_path,
+            progress=None,
+        )
+
+        assert bytes_received == 9
+        assert pack_path.read_bytes() == b"url-pack"
+        assert seen == {
+            "url": "https://example.test/pack",
+            "progress": None,
+            "sidecar_url": "https://example.test/pack",
+            "sidecar_path": sig_path,
+        }
+
+    def test_dispatch_pull_peer_downloads_pack_and_sidecar(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.peer as peer
+
+        pack_path = tmp_path / "pack.dlm.pack"
+        sig_path = tmp_path / "pack.dlm.pack.minisig"
+        seen: dict[str, object] = {}
+
+        def _fake_pull_peer(target: str, out_path: Path, *, progress: object | None = None) -> int:
+            seen["target"] = target
+            seen["progress"] = progress
+            out_path.write_bytes(b"peer-pack")
+            return 11
+
+        monkeypatch.setattr(peer, "pull_peer", _fake_pull_peer)
+        monkeypatch.setattr(
+            pull_mod,
+            "_try_peer_sidecar",
+            lambda target, sidecar_path: seen.update(
+                {"sidecar_target": target, "sidecar_path": sidecar_path}
+            ),
+        )
+
+        bytes_received = _dispatch_pull(
+            SinkSpec(kind=SinkKind.PEER, target="host:7337/pack?token=abc"),
+            pack_path,
+            sig_path,
+            progress=None,
+        )
+
+        assert bytes_received == 11
+        assert pack_path.read_bytes() == b"peer-pack"
+        assert seen == {
+            "target": "host:7337/pack?token=abc",
+            "progress": None,
+            "sidecar_target": "host:7337/pack?token=abc",
+            "sidecar_path": sig_path,
+        }
+
+    def test_dispatch_pull_local_copies_pack_and_signature(self, tmp_path: Path) -> None:
+        src = tmp_path / "src.dlm.pack"
+        sig = tmp_path / "src.dlm.pack.minisig"
+        src.write_bytes(b"local-pack")
+        sig.write_text("local-signature", encoding="utf-8")
+        pack_path = tmp_path / "incoming.dlm.pack"
+        sig_path = tmp_path / "incoming.dlm.pack.minisig"
+
+        bytes_received = _dispatch_pull(
+            SinkSpec(kind=SinkKind.LOCAL, target=str(src)),
+            pack_path,
+            sig_path,
+            progress=None,
+        )
+
+        assert bytes_received == len(b"local-pack")
+        assert pack_path.read_bytes() == b"local-pack"
+        assert sig_path.read_text(encoding="utf-8") == "local-signature"
+
+    def test_dispatch_pull_local_missing_source_raises(self, tmp_path: Path) -> None:
+        with pytest.raises(SinkError, match="source missing"):
+            _dispatch_pull(
+                SinkSpec(kind=SinkKind.LOCAL, target=str(tmp_path / "missing.dlm.pack")),
+                tmp_path / "incoming.dlm.pack",
+                tmp_path / "incoming.dlm.pack.minisig",
+                progress=None,
+            )
+
+    def test_dispatch_pull_rejects_unsupported_kind(self, tmp_path: Path) -> None:
+        weird = SinkSpec(kind=cast("SinkKind", "weird"), target="x")
+
+        with pytest.raises(ShareError, match="unsupported sink kind"):
+            _dispatch_pull(
+                weird,
+                tmp_path / "incoming.dlm.pack",
+                tmp_path / "incoming.dlm.pack.minisig",
+                progress=None,
+            )
+
+
+class TestPullSidecars:
+    def test_try_hf_sidecar_copies_downloaded_signature(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        fake_hub = ModuleType("huggingface_hub")
+        fake_utils = ModuleType("huggingface_hub.utils")
+        downloaded = tmp_path / "downloaded.minisig"
+        downloaded.write_text("hf-signature", encoding="utf-8")
+
+        class FakeHfHubHTTPError(Exception):
+            pass
+
+        def _fake_download(*, repo_id: str, filename: str, repo_type: str) -> str:
+            assert repo_id == "org/repo"
+            assert filename == "adapter.dlm.pack.minisig"
+            assert repo_type == "model"
+            return str(downloaded)
+
+        fake_hub.hf_hub_download = _fake_download
+        fake_utils.HfHubHTTPError = FakeHfHubHTTPError
+        monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
+        monkeypatch.setitem(sys.modules, "huggingface_hub.utils", fake_utils)
+
+        sig_path = tmp_path / "incoming.minisig"
+        _try_hf_sidecar("org/repo", sig_path)
+
+        assert sig_path.read_text(encoding="utf-8") == "hf-signature"
+
+    def test_try_hf_sidecar_suppresses_hub_errors(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        fake_hub = ModuleType("huggingface_hub")
+        fake_utils = ModuleType("huggingface_hub.utils")
+
+        class FakeHfHubHTTPError(Exception):
+            pass
+
+        def _fake_download(*, repo_id: str, filename: str, repo_type: str) -> str:
+            raise FakeHfHubHTTPError("missing")
+
+        fake_hub.hf_hub_download = _fake_download
+        fake_utils.HfHubHTTPError = FakeHfHubHTTPError
+        monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub)
+        monkeypatch.setitem(sys.modules, "huggingface_hub.utils", fake_utils)
+
+        sig_path = tmp_path / "incoming.minisig"
+        _try_hf_sidecar("org/repo", sig_path)
+
+        assert not sig_path.exists()
+
+    def test_try_url_sidecar_suppresses_missing_sidecar(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.url_sink as url_sink
+
+        def _fake_pull_url(url: str, out_path: Path, *, progress: object | None = None) -> int:
+            raise SinkError(f"missing {url}")
+
+        monkeypatch.setattr(url_sink, "pull_url", _fake_pull_url)
+
+        _try_url_sidecar("https://example.test/pack", tmp_path / "incoming.minisig")
+
+    def test_try_peer_sidecar_suppresses_missing_sidecar(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
350
+    ) -> None:
351
+        import dlm.share.peer as peer
352
+
353
+        def _fake_pull_peer(target: str, out_path: Path, *, progress: object | None = None) -> int:
354
+            raise SinkError(f"missing {target}")
355
+
356
+        monkeypatch.setattr(peer, "pull_peer", _fake_pull_peer)
357
+
358
+        _try_peer_sidecar("host:7337/pack?token=abc", tmp_path / "incoming.minisig")
359
+
360
+
361
+class TestVerificationLogging:
362
+    def test_log_verification_verified(
363
+        self, caplog: pytest.LogCaptureFixture, tmp_path: Path
364
+    ) -> None:
365
+        caplog.set_level("INFO")
366
+
367
+        _log_verification(
368
+            "hf:org/repo",
369
+            VerifyResult(status=VerifyStatus.VERIFIED, key_path=tmp_path / "trusted.pub"),
370
+        )
371
+
372
+        assert "verified signature" in caplog.text
373
+
374
+    def test_log_verification_unverified(self, caplog: pytest.LogCaptureFixture) -> None:
375
+        caplog.set_level("WARNING")
376
+
377
+        _log_verification(
378
+            "https://example.test/pack",
379
+            VerifyResult(status=VerifyStatus.UNVERIFIED, detail="no trusted key matched"),
380
+        )
381
+
382
+        assert "signature present but could not verify" in caplog.text
383
+        assert "no trusted key matched" in caplog.text
384
+
385
+    def test_log_verification_unsigned(self, caplog: pytest.LogCaptureFixture) -> None:
386
+        caplog.set_level("INFO")
387
+
388
+        _log_verification(
389
+            "./local.dlm.pack",
390
+            VerifyResult(status=VerifyStatus.UNSIGNED),
391
+        )
392
+
393
+        assert "no signature" in caplog.text
tests/unit/share/test_push.py (added, 435 lines changed)
@@ -0,0 +1,435 @@
+"""Unit coverage for the share push orchestrator."""
+
+from __future__ import annotations
+
+import importlib
+import io
+import json
+import tarfile
+from pathlib import Path
+from types import SimpleNamespace
+from typing import cast
+
+import pytest
+import zstandard as zstd
+
+from dlm.share.errors import ShareError, SinkError
+from dlm.share.push import (
+    PushResult,
+    _collect_readme_fields,
+    _dispatch_push,
+    _ensure_pack,
+    _noop,
+    _sign_pack,
+    push,
+)
+from dlm.share.sinks import SinkKind, SinkSpec
+
+push_mod = importlib.import_module("dlm.share.push")
+
+
+def _write_pack_with_header(tmp_path: Path, header: dict[str, str]) -> Path:
+    tar_bytes = io.BytesIO()
+    with tarfile.open(fileobj=tar_bytes, mode="w") as tar:
+        payload = json.dumps(header).encode("utf-8")
+        info = tarfile.TarInfo("pack/header.json")
+        info.size = len(payload)
+        tar.addfile(info, io.BytesIO(payload))
+    pack_path = tmp_path / "bundle.dlm.pack"
+    with pack_path.open("wb") as dst, zstd.ZstdCompressor().stream_writer(dst) as writer:
+        writer.write(tar_bytes.getvalue())
+    return pack_path
+
+
+class TestPush:
+    def test_push_rejects_peer_destinations(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        monkeypatch.setattr(
+            push_mod,
+            "parse_source",
+            lambda destination: SinkSpec(kind=SinkKind.PEER, target=destination),
+        )
+
+        with pytest.raises(ShareError, match="push to peer:// is not supported"):
+            push(tmp_path / "doc.dlm", "peer://host:7337/doc?token=abc")
+
+    def test_push_signs_dispatches_and_cleans_up(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        source = tmp_path / "doc.dlm"
+        source.write_text("body", encoding="utf-8")
+        pack_path = tmp_path / "doc.dlm.pack"
+        cleanup_called = False
+        order: list[str] = []
+        progress = object()
+        expected = PushResult(
+            destination="https://example.test/upload",
+            sink_kind=SinkKind.URL,
+            bytes_sent=11,
+        )
+
+        monkeypatch.setattr(
+            push_mod,
+            "parse_source",
+            lambda destination: SinkSpec(kind=SinkKind.URL, target=destination),
+        )
+
+        def _fake_ensure_pack(
+            actual_source: Path,
+            *,
+            include_exports: bool,
+            include_base: bool,
+            include_logs: bool,
+            licensee_acceptance_url: str | None,
+        ) -> tuple[Path, object]:
+            nonlocal cleanup_called
+            assert actual_source == source
+            assert include_exports is True
+            assert include_base is True
+            assert include_logs is True
+            assert licensee_acceptance_url == "https://license.example/accept"
+            pack_path.write_bytes(b"packed-bytes")
+
+            def _cleanup() -> None:
+                nonlocal cleanup_called
+                cleanup_called = True
+
+            return pack_path, _cleanup
+
+        def _fake_sign_pack(actual_pack: Path) -> None:
+            order.append("sign")
+            assert actual_pack == pack_path
+
+        def _fake_dispatch(
+            actual_pack: Path,
+            spec: SinkSpec,
+            *,
+            progress: object | None,
+        ) -> PushResult:
+            order.append("dispatch")
+            assert actual_pack == pack_path
+            assert spec == SinkSpec(
+                kind=SinkKind.URL,
+                target="https://example.test/upload",
+            )
+            assert progress is not None
+            return expected
+
+        monkeypatch.setattr(push_mod, "_ensure_pack", _fake_ensure_pack)
+        monkeypatch.setattr(push_mod, "_sign_pack", _fake_sign_pack)
+        monkeypatch.setattr(push_mod, "_dispatch_push", _fake_dispatch)
+
+        result = push(
+            source,
+            "https://example.test/upload",
+            sign=True,
+            include_exports=True,
+            include_base=True,
+            include_logs=True,
+            licensee_acceptance_url="https://license.example/accept",
+            progress=cast("object", progress),
+        )
+
+        assert result == expected
+        assert order == ["sign", "dispatch"]
+        assert cleanup_called is True
+
+    def test_push_cleans_up_when_dispatch_raises(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        source = tmp_path / "doc.dlm"
+        source.write_text("body", encoding="utf-8")
+        pack_path = tmp_path / "doc.dlm.pack"
+        cleanup_called = False
+
+        monkeypatch.setattr(
+            push_mod,
+            "parse_source",
+            lambda destination: SinkSpec(kind=SinkKind.URL, target=destination),
+        )
+        monkeypatch.setattr(
+            push_mod,
+            "_ensure_pack",
+            lambda *args, **kwargs: (
+                pack_path,
+                lambda: globals().__setitem__("_unused", None),
+            ),
+        )
+
+        def _cleanup() -> None:
+            nonlocal cleanup_called
+            cleanup_called = True
+
+        monkeypatch.setattr(push_mod, "_ensure_pack", lambda *args, **kwargs: (pack_path, _cleanup))
+        monkeypatch.setattr(
+            push_mod,
+            "_dispatch_push",
+            lambda *args, **kwargs: (_ for _ in ()).throw(SinkError("boom")),
+        )
+
+        with pytest.raises(SinkError, match="boom"):
+            push(source, "https://example.test/upload")
+
+        assert cleanup_called is True
+
+
+class TestEnsurePack:
+    def test_ensure_pack_keeps_existing_pack(self, tmp_path: Path) -> None:
+        pack_path = tmp_path / "doc.dlm.pack"
+        pack_path.write_bytes(b"already-packed")
+
+        actual_path, cleanup = _ensure_pack(
+            pack_path,
+            include_exports=False,
+            include_base=False,
+            include_logs=False,
+            licensee_acceptance_url=None,
+        )
+
+        assert actual_path == pack_path
+        assert cleanup is _noop
+
+    def test_ensure_pack_packs_dlm_and_cleans_up(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        source = tmp_path / "doc.dlm"
+        source.write_text("body", encoding="utf-8")
+        seen: dict[str, object] = {}
+
+        def _fake_pack(
+            actual_source: Path,
+            *,
+            out: Path,
+            include_exports: bool,
+            include_base: bool,
+            include_logs: bool,
+            licensee_acceptance_url: str | None,
+        ) -> SimpleNamespace:
+            seen["source"] = actual_source
+            seen["out"] = out
+            seen["include_exports"] = include_exports
+            seen["include_base"] = include_base
+            seen["include_logs"] = include_logs
+            seen["license"] = licensee_acceptance_url
+            out.write_bytes(b"packed")
+            return SimpleNamespace(path=out)
+
+        monkeypatch.setattr(push_mod, "pack", _fake_pack)
+
+        actual_path, cleanup = _ensure_pack(
+            source,
+            include_exports=True,
+            include_base=True,
+            include_logs=True,
+            licensee_acceptance_url="https://license.example/accept",
+        )
+
+        temp_dir = actual_path.parent
+        assert actual_path.read_bytes() == b"packed"
+        assert seen == {
+            "source": source,
+            "out": actual_path,
+            "include_exports": True,
+            "include_base": True,
+            "include_logs": True,
+            "license": "https://license.example/accept",
+        }
+
+        cleanup()
+        assert not temp_dir.exists()
+
+
+class TestSignPack:
+    def test_sign_pack_calls_sign_file(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        import dlm.share.signing as signing
+
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+        sig_path = pack_path.with_suffix(pack_path.suffix + ".minisig")
+        seen: dict[str, object] = {}
+
+        def _fake_sign_file(target: Path, *, comment: str | None = None) -> Path:
+            seen["target"] = target
+            seen["comment"] = comment
+            sig_path.write_text("signature", encoding="utf-8")
+            return sig_path
+
+        monkeypatch.setattr(signing, "sign_file", _fake_sign_file)
+
+        _sign_pack(pack_path)
+
+        assert seen == {
+            "target": pack_path,
+            "comment": f"dlm push {pack_path.name}",
+        }
+
+    def test_sign_pack_propagates_missing_minisign(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        import dlm.share.signing as signing
+
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+
+        def _fake_sign_file(target: Path, *, comment: str | None = None) -> Path:
+            raise signing.MinisignNotAvailableError("missing")
+
+        monkeypatch.setattr(signing, "sign_file", _fake_sign_file)
+
+        with pytest.raises(signing.MinisignNotAvailableError, match="missing"):
+            _sign_pack(pack_path)
+
+
+class TestDispatchPush:
+    def test_dispatch_push_hf_uploads_pack(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        import dlm.share.hf_sink as hf_sink
+
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+        progress = object()
+        seen: dict[str, object] = {}
+
+        def _fake_push_hf(
+            actual_pack: Path,
+            repo_id: str,
+            *,
+            private: bool = False,
+            readme_fields: dict[str, str] | None = None,
+            progress: object | None = None,
+        ) -> SimpleNamespace:
+            seen["pack"] = actual_pack
+            seen["repo_id"] = repo_id
+            seen["private"] = private
+            seen["readme_fields"] = readme_fields
+            seen["progress"] = progress
+            return SimpleNamespace(
+                pack_url="https://huggingface.co/org/repo/blob/main/adapter.dlm.pack"
+            )
+
+        monkeypatch.setattr(hf_sink, "push_hf", _fake_push_hf)
+        monkeypatch.setattr(
+            push_mod,
+            "_collect_readme_fields",
+            lambda path: {"dlm_id": "01HZPUSH", "base_model": "qwen3-4b"},
+        )
+
+        result = _dispatch_push(
+            pack_path,
+            SinkSpec(kind=SinkKind.HF, target="org/repo"),
+            progress=cast("object", progress),
+        )
+
+        assert result == PushResult(
+            destination="hf:org/repo",
+            sink_kind=SinkKind.HF,
+            bytes_sent=len(b"packed"),
+            detail="pack: https://huggingface.co/org/repo/blob/main/adapter.dlm.pack",
+        )
+        assert seen == {
+            "pack": pack_path,
+            "repo_id": "org/repo",
+            "private": False,
+            "readme_fields": {"dlm_id": "01HZPUSH", "base_model": "qwen3-4b"},
+            "progress": progress,
+        }
+
+    def test_dispatch_push_url_uploads_pack(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        import dlm.share.url_sink as url_sink
+
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+        seen: dict[str, object] = {}
+
+        def _fake_push_url(actual_pack: Path, url: str, *, progress: object | None = None) -> None:
+            seen["pack"] = actual_pack
+            seen["url"] = url
+            seen["progress"] = progress
+
+        monkeypatch.setattr(url_sink, "push_url", _fake_push_url)
+
+        result = _dispatch_push(
+            pack_path,
+            SinkSpec(kind=SinkKind.URL, target="https://example.test/upload"),
+            progress=None,
+        )
+
+        assert result == PushResult(
+            destination="https://example.test/upload",
+            sink_kind=SinkKind.URL,
+            bytes_sent=len(b"packed"),
+        )
+        assert seen == {
+            "pack": pack_path,
+            "url": "https://example.test/upload",
+            "progress": None,
+        }
+
+    def test_dispatch_push_local_copies_pack(self, tmp_path: Path) -> None:
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+        dest = tmp_path / "nested" / "copy.dlm.pack"
+
+        result = _dispatch_push(
+            pack_path,
+            SinkSpec(kind=SinkKind.LOCAL, target=str(dest)),
+            progress=None,
+        )
+
+        assert result == PushResult(
+            destination=str(dest),
+            sink_kind=SinkKind.LOCAL,
+            bytes_sent=len(b"packed"),
+        )
+        assert dest.read_bytes() == b"packed"
+
+    def test_dispatch_push_rejects_unsupported_kind(self, tmp_path: Path) -> None:
+        pack_path = tmp_path / "bundle.dlm.pack"
+        pack_path.write_bytes(b"packed")
+
+        with pytest.raises(SinkError, match="unsupported sink kind"):
+            _dispatch_push(
+                pack_path,
+                SinkSpec(kind=cast("SinkKind", "weird"), target="x"),
+                progress=None,
+            )
+
+
+class TestReadmeFields:
+    def test_collect_readme_fields_from_pack(self, tmp_path: Path) -> None:
+        pack_path = _write_pack_with_header(
+            tmp_path,
+            {
+                "dlm_id": "01HZHEADER",
+                "base_model": "qwen3-8b",
+                "adapter_version": "v0007",
+            },
+        )
+
+        assert _collect_readme_fields(pack_path) == {
+            "dlm_id": "01HZHEADER",
+            "base_model": "qwen3-8b",
+            "adapter_version": "v0007",
+        }
+
+    def test_collect_readme_fields_returns_empty_on_bad_pack(self, tmp_path: Path) -> None:
+        assert _collect_readme_fields(tmp_path / "missing.dlm.pack") == {}
+
+    def test_noop_is_noop(self) -> None:
+        assert _noop() is None
tests/unit/share/test_signing.py (modified, 149 lines changed)
@@ -7,6 +7,7 @@ from unittest.mock import patch
 
 import pytest
 
+from dlm.share.errors import ShareError
 from dlm.share.signing import (
     MinisignNotAvailableError,
     VerifyResult,
@@ -81,3 +82,142 @@ class TestSignRefusesWithoutBinary:
             pytest.raises(MinisignNotAvailableError, match="not installed"),
         ):
             sign_file(target)
+
+
+class TestSignFile:
+    def test_missing_secret_key_is_refused(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            pytest.raises(Exception, match="secret key not found"),
+        ):
+            sign_file(target, secret_key=tmp_path / "missing.key")
+
+    def test_nonzero_exit_is_refused(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        secret = tmp_path / "secret.key"
+        secret.write_text("key", encoding="utf-8")
+
+        class Result:
+            returncode = 7
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            patch("subprocess.run", return_value=Result()),
+            pytest.raises(Exception, match="exit 7"),
+        ):
+            sign_file(target, secret_key=secret, comment="demo")
+
+    def test_missing_signature_sidecar_after_success_is_refused(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        secret = tmp_path / "secret.key"
+        secret.write_text("key", encoding="utf-8")
+
+        class Result:
+            returncode = 0
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            patch("subprocess.run", return_value=Result()),
+            pytest.raises(Exception, match="is missing"),
+        ):
+            sign_file(target, secret_key=secret)
+
+    def test_happy_path_returns_minisig_path(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        secret = tmp_path / "secret.key"
+        secret.write_text("key", encoding="utf-8")
+        sig = target.with_suffix(target.suffix + ".minisig")
+
+        class Result:
+            returncode = 0
+
+        def _fake_run(cmd: list[str], check: bool) -> Result:
+            assert "-c" in cmd
+            sig.write_text("signature", encoding="utf-8")
+            return Result()
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            patch("subprocess.run", side_effect=_fake_run),
+        ):
+            out = sign_file(target, secret_key=secret, comment="demo")
+
+        assert out == sig
+
+
+class TestVerifySignature:
+    def test_verified_when_one_key_matches(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        sig = tmp_path / "pack.bin.minisig"
+        sig.write_bytes(b"sig")
+        keys = tmp_path / "trusted-keys"
+        keys.mkdir()
+        miss = keys / "miss.pub"
+        hit = keys / "hit.pub"
+        miss.write_text("miss", encoding="utf-8")
+        hit.write_text("hit", encoding="utf-8")
+        seen: list[Path] = []
+
+        def _fake_verify(_target: Path, _sig: Path, pub_key: Path) -> None:
+            seen.append(pub_key)
+            if pub_key == miss:
+                raise Exception("bad key")
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            patch("dlm.share.signing._minisign_verify", side_effect=_fake_verify),
+        ):
+            result = verify_signature(target, sig, trusted_keys_dir=keys)
+
+        assert result.status == VerifyStatus.VERIFIED
+        assert result.key_path == hit
+        assert seen == [hit]
+
+    def test_unverified_when_no_keys_match(self, tmp_path: Path) -> None:
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        sig = tmp_path / "pack.bin.minisig"
+        sig.write_bytes(b"sig")
+        keys = tmp_path / "trusted-keys"
+        keys.mkdir()
+        (keys / "a.pub").write_text("a", encoding="utf-8")
+        (keys / "b.pub").write_text("b", encoding="utf-8")
+
+        with (
+            patch("dlm.share.signing.minisign_available", return_value=True),
+            patch("dlm.share.signing._minisign_verify", side_effect=ShareError("no match")),
+        ):
+            result = verify_signature(target, sig, trusted_keys_dir=keys)
+
+        assert result.status == VerifyStatus.UNVERIFIED
+        assert "no match among 2 trusted keys" in result.detail
+
+
+class TestMinisignVerify:
+    def test_verify_raises_share_error_on_nonzero_exit(self, tmp_path: Path) -> None:
+        from dlm.share.errors import ShareError
+        from dlm.share.signing import _minisign_verify
+
+        target = tmp_path / "pack.bin"
+        target.write_bytes(b"payload")
+        sig = tmp_path / "pack.bin.minisig"
+        sig.write_bytes(b"sig")
+        key = tmp_path / "key.pub"
+        key.write_text("key", encoding="utf-8")
+
+        class Result:
+            returncode = 1
+            stderr = b"bad signature"
+
+        with (
+            patch("subprocess.run", return_value=Result()),
+            pytest.raises(ShareError, match="bad signature"),
+        ):
+            _minisign_verify(target, sig, key)
tests/unit/share/test_url_sink.py (modified, 93 lines changed)
@@ -124,6 +124,32 @@ class TestPushUrl:
         with pytest.raises(SinkError, match="network error"):
             push_url(pack, "https://example.com/upload")
 
+    def test_non_2xx_response_object_is_refused(
+        self, pack: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        def _fake_urlopen(req: urllib.request.Request, data: object, timeout: int) -> _FakeResponse:
+            return _FakeResponse(status=500)
+
+        monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
+
+        with pytest.raises(SinkError, match="HTTP 500"):
+            push_url(pack, "https://example.com/upload")
+
+    def test_io_error_reading_pack_is_translated(
+        self, pack: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        real_open = Path.open
+
+        def _broken_open(self: Path, *args: object, **kwargs: object):
+            if self == pack:
+                raise OSError("disk error")
+            return real_open(self, *args, **kwargs)
+
+        monkeypatch.setattr(Path, "open", _broken_open)
+
+        with pytest.raises(SinkError, match="I/O error reading"):
+            push_url(pack, "https://example.com/upload")
+
     def test_progress_called_at_start_and_end(
         self, pack: Path, monkeypatch: pytest.MonkeyPatch
     ) -> None:
@@ -216,6 +242,61 @@ class TestPullUrl:
         with pytest.raises(SinkError, match="HTTP 404"):
             pull_url("https://example.com/p", tmp_path / "out.pack")
 
+    def test_non_2xx_response_object_is_refused(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        def _fake_urlopen(req: urllib.request.Request, timeout: int) -> _FakeResponse:
+            return _FakeResponse(status=503, body=b"down")
+
+        monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
+        with pytest.raises(SinkError, match="HTTP 503"):
+            pull_url("https://example.com/p", tmp_path / "out.pack")
+
+    def test_network_error_is_translated(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        def _fake_urlopen(req: urllib.request.Request, timeout: int) -> _FakeResponse:
+            raise urllib.error.URLError("reset")
+
+        monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
+        with pytest.raises(SinkError, match="network error"):
+            pull_url("https://example.com/p", tmp_path / "out.pack")
+
+    def test_io_error_writing_is_translated(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        out = tmp_path / "nested" / "out.pack"
+
+        def _fake_urlopen(req: urllib.request.Request, timeout: int) -> _FakeResponse:
+            return _FakeResponse(status=200, body=b"payload", headers={"Content-Length": "7"})
+
+        real_open = Path.open
+
+        def _broken_open(self: Path, *args: object, **kwargs: object):
+            if self == out:
+                raise OSError("read only")
+            return real_open(self, *args, **kwargs)
+
+        monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
+        monkeypatch.setattr(Path, "open", _broken_open)
+
+        with pytest.raises(SinkError, match="I/O error writing"):
+            pull_url("https://example.com/p", out)
+
+    def test_http_scheme_warns(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        out = tmp_path / "fetched.pack"
+
+        def _fake_urlopen(req: urllib.request.Request, timeout: int) -> _FakeResponse:
+            return _FakeResponse(status=200, body=b"x", headers={"Content-Length": "1"})
+
+        monkeypatch.setattr(urllib.request, "urlopen", _fake_urlopen)
+        with caplog.at_level("WARNING", logger="dlm.share.url_sink"):
+            pull_url("http://example.com/p", out)
+
+        assert any("plaintext HTTP" in rec.message for rec in caplog.records)
+
     def test_creates_parent_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
         out = tmp_path / "nested" / "dir" / "fetched.pack"
 
tests/unit/store/test_blobs.py (modified, 47 lines changed)
@@ -90,6 +90,12 @@ class TestBlobStorePut:
         from_bytes = store.put_bytes(data, ext=".jpg")
         assert from_path == from_bytes
 
+    def test_put_bytes_writes_new_blob(self, store: BlobStore, store_root: Path) -> None:
+        data = b"raw bytes"
+        handle = store.put_bytes(data, ext="PNG")
+        blob_path = store_root / handle.sha[:2] / f"{handle.sha}.png"
+        assert blob_path.read_bytes() == data
+
 
 class TestBlobStoreGet:
     def test_get_returns_stored_path(self, store: BlobStore, tmp_path: Path) -> None:
@@ -145,6 +151,14 @@ class TestBlobStoreGC:
     def test_gc_noop_on_empty_store(self, store: BlobStore) -> None:
         assert list(store.gc({"a" * 64})) == []
 
+    def test_gc_ignores_concurrent_delete(self, store: BlobStore, tmp_path: Path) -> None:
+        src = tmp_path / "x.png"
+        src.write_bytes(b"x")
+        handle = store.put(src)
+        store.get(handle.sha).unlink()
+        store.iter_all = lambda: iter([handle])  # type: ignore[method-assign]
+        assert list(store.gc(set())) == []
+
 
 class TestBlobStoreExtensions:
     @pytest.mark.parametrize(
@@ -208,6 +222,13 @@ class TestBlobStoreIteration:
         iterated = list(store.iter_all())
         assert sorted(h.sha for h in handles) == sorted(h.sha for h in iterated)
 
+    def test_iter_all_ignores_non_blob_entries(self, store: BlobStore, store_root: Path) -> None:
+        bucket = store_root / "aa"
+        bucket.mkdir(parents=True, exist_ok=True)
+        (store_root / "README.txt").write_text("ignore me", encoding="utf-8")
+        (bucket / "nested").mkdir()
+        assert list(store.iter_all()) == []
+
 
 class TestBlobStoreClear:
     def test_clear_removes_tree(self, store: BlobStore, store_root: Path, tmp_path: Path) -> None:
@@ -224,3 +245,8 @@ class TestBlobHandleValue:
         h = BlobHandle(sha="a" * 64, ext=".png", size=10)
         with pytest.raises(AttributeError):
             h.sha = "b" * 64  # type: ignore[misc]
+
+
+class TestBlobStoreMetadata:
+    def test_root_property(self, store: BlobStore, store_root: Path) -> None:
+        assert store.root == store_root
tests/unit/store/test_inspect.py (modified)
66 lines changed
@@ -7,7 +7,7 @@ from pathlib import Path
 
 import pytest
 
-from dlm.store.inspect import inspect_store
+from dlm.store.inspect import _directory_size, _discover_named_adapters, _max_version, inspect_store
 from dlm.store.manifest import Manifest, TrainingRunSummary, save_manifest
 from dlm.store.paths import StorePath, for_dlm
 from tests.fixtures.dlm_factory import make_dlm
@@ -203,3 +203,58 @@ class TestTimelineEdges:
         save_manifest(store.manifest, manifest)
         result = inspect_store(store)
         assert result.last_trained_at == base + timedelta(minutes=5)
+
+
+class TestInspectCoverageEdges:
+    def test_discover_named_adapters_on_missing_adapter_dir(self, tmp_path: Path) -> None:
+        store = for_dlm(VALID_ID, home=tmp_path)
+        assert _discover_named_adapters(store) == []
+
+    def test_directory_size_ignores_stat_errors(self, tmp_path: Path) -> None:
+        path = tmp_path / "root"
+        path.mkdir()
+        good = path / "good.bin"
+        good.write_bytes(b"1234")
+
+        class _BadPath:
+            def is_file(self) -> bool:
+                return True
+
+            def stat(self):  # type: ignore[no-untyped-def]
+                raise OSError("transient")
+
+        monkeypatch = pytest.MonkeyPatch()
+        monkeypatch.setattr(Path, "rglob", lambda self, _pattern: iter([good, _BadPath()]))
+        try:
+            assert _directory_size(path) == 4
+        finally:
+            monkeypatch.undo()
+
+    def test_discover_named_adapters_tolerates_pointer_probe_errors(self, tmp_path: Path) -> None:
+        store = for_dlm(VALID_ID, home=tmp_path)
+        store.ensure_layout()
+        named = store.adapter / "knowledge"
+        (named / "versions" / "v0002").mkdir(parents=True)
+
+        def _boom(_name: str) -> None:
+            raise OSError("pointer unreadable")
+
+        monkeypatch = pytest.MonkeyPatch()
+        monkeypatch.setattr(
+            StorePath, "resolve_current_adapter_for", lambda self, _name: _boom(_name)
+        )
+        try:
+            states = _discover_named_adapters(store)
+        finally:
+            monkeypatch.undo()
+
+        assert states == [type(states[0])(name="knowledge", has_current=False, latest_version=2)]
+
+    def test_max_version_ignores_non_version_entries(self, tmp_path: Path) -> None:
+        versions = tmp_path / "versions"
+        versions.mkdir()
+        (versions / "v0002").mkdir()
+        (versions / "vbad").mkdir()
+        (versions / "notes").mkdir()
+        (versions / "v0009").write_text("not a dir", encoding="utf-8")
+        assert _max_version(versions) == 2
tests/unit/store/test_lock.py (modified)
70 lines changed
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import errno
 import multiprocessing
 import os
 import time
@@ -16,7 +17,7 @@ from dlm.store.errors import LockHeldError, StaleLockError
 # Module-level worker fns so `spawn` context can pickle them.
 
 
-def _child_attempt(path: str, queue: _MPQueue) -> None:
+def _child_attempt(path: str, queue: _MPQueue[str]) -> None:
     try:
         with lock.exclusive(Path(path), timeout=0.0):
             queue.put("acquired")
@@ -64,7 +65,7 @@ class TestMutualExclusion:
         # Parent acquires and releases in sequence; we hold inside the `with`
         # block long enough for the child to try + fail.
         ctx = multiprocessing.get_context("spawn")
-        outcome: _MPQueue = ctx.Queue()
+        outcome: _MPQueue[str] = ctx.Queue()
 
         with lock.exclusive(lock_path):
             proc = ctx.Process(target=_child_attempt, args=(str(lock_path), outcome))
@@ -121,3 +122,47 @@ class TestStaleLock:
         with pytest.raises(StaleLockError) as exc, lock.exclusive(lock_path, timeout=0.0):
             pass
         assert exc.value.holder_pid is None
+
+    def test_transient_empty_lockfile_retries_when_timeout_allows(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+        tmp_path: Path,
+    ) -> None:
+        lock_path = tmp_path / "test.lock"
+        payload = lock.LockInfo(pid=os.getpid(), hostname="host", acquired_at=time.time())
+        acquire_results = iter([False, True])
+        read_results = iter([None, payload])
+
+        monkeypatch.setattr(lock, "_acquire_once", lambda _path: next(acquire_results))
+        monkeypatch.setattr(lock, "_read_lock", lambda _path: next(read_results))
+        monkeypatch.setattr(lock, "_release", lambda _path: None)
+
+        with lock.exclusive(lock_path, timeout=1.0, poll_interval=0.0) as info:
+            assert info == payload
+
+
+class TestProcessProbeEdges:
+    def test_permission_error_treated_as_alive(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setattr(
+            "dlm.store.lock.os.kill",
+            lambda _pid, _sig: (_ for _ in ()).throw(PermissionError()),
+        )
+        assert lock._is_alive(123) is True
+
+    def test_generic_oserror_treated_as_alive(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        def _raise(_pid: int, _sig: int) -> None:
+            err = OSError("io")
+            err.errno = errno.EIO
+            raise err
+
+        monkeypatch.setattr("dlm.store.lock.os.kill", _raise)
+        assert lock._is_alive(123) is True
+
+    def test_esrch_oserror_treated_as_dead(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        def _raise(_pid: int, _sig: int) -> None:
+            err = OSError("missing")
+            err.errno = errno.ESRCH
+            raise err
+
+        monkeypatch.setattr("dlm.store.lock.os.kill", _raise)
+        assert lock._is_alive(123) is False
tests/unit/store/test_paths.py (modified)
78 lines changed
@@ -13,12 +13,17 @@ from dlm.store.layout import (
     LOGS_DIR,
     MANIFEST_FILENAME,
 )
-from dlm.store.paths import StorePath, dlm_home, ensure_home, for_dlm
+from dlm.store.paths import StorePath, _current_os_name, dlm_home, ensure_home, for_dlm
 
 VALID_ID = "01HZ4X7TGZM3J1A2B3C4D5E6F7"
 
 
 class TestDlmHome:
+    def test_current_os_name_passthrough(self) -> None:
+        import os
+
+        assert _current_os_name() == os.name
+
     def test_override_takes_precedence(self, tmp_path: Path) -> None:
         assert dlm_home(override=tmp_path / "custom") == (tmp_path / "custom").resolve()
 
@@ -39,6 +44,14 @@ class TestDlmHome:
         monkeypatch.setattr(Path, "home", lambda: tmp_path / "u")
         assert dlm_home() == tmp_path / "u" / ".dlm"
 
+    def test_default_on_nt_prefers_appdata(
+        self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+    ) -> None:
+        monkeypatch.delenv("DLM_HOME", raising=False)
+        monkeypatch.setenv("APPDATA", str(tmp_path / "AppData" / "Roaming"))
+        monkeypatch.setattr("dlm.store.paths._current_os_name", lambda: "nt")
+        assert dlm_home() == (tmp_path / "AppData" / "Roaming").resolve() / "dlm"
+
 
 class TestEnsureHome:
     def test_creates_store_subdir(self, tmp_path: Path) -> None:
@@ -69,6 +82,10 @@ class TestStorePathAccessors:
     def test_lock_path(self, store: StorePath) -> None:
         assert store.lock.name == LOCK_FILENAME
 
+    def test_training_state_paths(self, store: StorePath) -> None:
+        assert store.training_state.name == "training_state.pt"
+        assert store.training_state_sha.name == "training_state.pt.sha256"
+
     def test_adapter_subpaths(self, store: StorePath) -> None:
         assert store.adapter.name == ADAPTER_DIR
         assert store.adapter_versions.parent == store.adapter
@@ -78,6 +95,10 @@ class TestStorePathAccessors:
     def test_logs_dir(self, store: StorePath) -> None:
         assert store.logs.name == LOGS_DIR
 
+    def test_replay_paths(self, store: StorePath) -> None:
+        assert store.replay_corpus.name == "corpus.zst"
+        assert store.replay_index.name == "index.json"
+
     def test_adapter_version_zero_rejected(self, store: StorePath) -> None:
         with pytest.raises(ValueError, match="1-indexed"):
             store.adapter_version(0)
@@ -104,12 +125,26 @@ class TestStorePathAccessors:
         assert store.vl_cache_dir.name == "vl-cache"
         assert store.vl_cache_dir.parent == store.root
 
+    def test_other_lazy_dirs(self, store: StorePath) -> None:
+        assert store.tokenized_cache_dir.name == "tokenized-cache"
+        assert store.audio_cache_dir.name == "audio-cache"
+        assert store.audio_waveform_cache_dir.name == "audio-waveform-cache"
+        assert store.controls_dir.name == "controls"
+        assert store.control_file("demo").name == "demo.safetensors"
+        assert store.control_meta("demo").name == "demo.meta.json"
+
     def test_blob_and_vl_cache_lazy(self, tmp_path: Path) -> None:
         sp = for_dlm(VALID_ID, home=tmp_path)
         sp.ensure_layout()
         assert not sp.blob_dir.exists()
         assert not sp.vl_cache_dir.exists()
 
+    def test_exists_reflects_store_root(self, tmp_path: Path) -> None:
+        sp = for_dlm(VALID_ID, home=tmp_path)
+        assert sp.exists() is False
+        sp.ensure_layout()
+        assert sp.exists() is True
+
 
 class TestEnsureLayout:
     @pytest.fixture
tests/unit/synth/test_apply_pending.py (added)
337 lines changed
@@ -0,0 +1,337 @@
+"""Tests for synth apply/revert and pending-plan helpers."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+
+from dlm.doc.parser import ParsedDlm, parse_file, parse_text
+from dlm.doc.sections import Section, SectionType
+from dlm.doc.serializer import serialize
+from dlm.store.paths import for_dlm
+from dlm.synth.apply import (
+    SynthApplySkipReason,
+    apply_plan,
+    build_apply_plan,
+    render_apply_plan,
+    revert_all_auto_synth,
+)
+from dlm.synth.pending import (
+    PendingSynthPlanError,
+    _optional_float,
+    _optional_int,
+    _optional_str,
+    _section_from_payload,
+    clear_pending_plan,
+    load_pending_plan,
+    pending_plan_path,
+    save_pending_plan,
+)
+
+_DLM_ID = "01KPQ9X1000000000000000000"
+_FRONTMATTER = f"---\ndlm_id: {_DLM_ID}\ndlm_version: 15\nbase_model: smollm2-135m\n---\n"
+
+
+def _write_dlm(path: Path, body: str = "") -> None:
+    path.write_text(_FRONTMATTER + body, encoding="utf-8")
+
+
+def _auto_synth_instruction(
+    *,
+    question: str = "What does DGEMM do?",
+    answer: str = "It multiplies dense matrices.",
+    teacher: str = "self",
+    strategy: str = "extraction",
+    source_section_id: str = "0123456789abcdef",
+) -> Section:
+    return Section(
+        type=SectionType.INSTRUCTION,
+        content=f"### Q\n{question}\n### A\n{answer}",
+        start_line=12,
+        adapter="tone",
+        tags={"topic": "blas"},
+        auto_synth=True,
+        synth_teacher=teacher,
+        synth_strategy=strategy,
+        synth_at="2026-04-24T20:00:00Z",
+        source_section_id=source_section_id,
+    )
+
+
+def _authored_instruction() -> Section:
+    return Section(
+        type=SectionType.INSTRUCTION,
+        content="### Q\nWhat is BLAS?\n### A\nA linear algebra interface.",
+    )
+
+
+def _preference() -> Section:
+    return Section(
+        type=SectionType.PREFERENCE,
+        content="### Prompt\nmanual\n### Chosen\nyes\n### Rejected\nno",
+    )
+
+
+def _image() -> Section:
+    return Section(
+        type=SectionType.IMAGE,
+        content="A DGEMM block diagram.",
+        media_path="diagram.png",
+        media_alt="DGEMM diagram",
+        media_blob_sha="ab" * 32,
+    )
+
+
+class TestBuildApplyPlan:
+    def test_accepts_new_auto_synth_instruction(self) -> None:
+        parsed = parse_text(_FRONTMATTER + "prose body\n")
+        plan = build_apply_plan(parsed, [_auto_synth_instruction()])
+
+        assert len(plan.additions) == 1
+        assert plan.skipped == ()
+        assert plan.additions[0].section.auto_synth is True
+
+    def test_dedupes_within_input(self) -> None:
+        parsed = parse_text(_FRONTMATTER + "prose body\n")
+        section = _auto_synth_instruction()
+        plan = build_apply_plan(parsed, [section, section])
+
+        assert len(plan.additions) == 1
+        assert len(plan.skipped) == 1
+        assert plan.skipped[0].reason is SynthApplySkipReason.ALREADY_PRESENT
+
+    def test_skips_non_instruction_and_hand_authored(self) -> None:
+        parsed = parse_text(_FRONTMATTER + "prose body\n")
+        plan = build_apply_plan(parsed, [_preference(), _authored_instruction()])
+
+        assert plan.additions == ()
+        assert [skip.reason for skip in plan.skipped] == [
+            SynthApplySkipReason.NOT_INSTRUCTION,
+            SynthApplySkipReason.NOT_AUTO_SYNTH,
+        ]
+
+    def test_render_plan_mentions_adds_and_skips(self) -> None:
+        parsed = parse_text(_FRONTMATTER + "prose body\n")
+        plan = build_apply_plan(parsed, [_auto_synth_instruction(), _authored_instruction()])
+        rendered = render_apply_plan(plan)
+
+        assert "1 add, 1 skip" in rendered
+        assert "::instruction::" in rendered
+        assert "teacher=self" in rendered
+        assert "strategy=extraction" in rendered
+        assert "source=0123456789abcdef" in rendered
+        assert "not_auto_synth" in rendered
+
+
+class TestApplyPlan:
+    def test_writes_additions_and_preserves_body(self, tmp_path: Path) -> None:
+        target = tmp_path / "doc.dlm"
+        _write_dlm(target, "## hello\n\nkeep me\n")
+
+        parsed = parse_file(target)
+        plan = build_apply_plan(parsed, [_auto_synth_instruction()])
+        summary = apply_plan(parsed, plan, target=target)
+
+        assert summary.added == 1
+        assert summary.skipped == 0
+        assert len(summary.added_section_ids) == 1
+
+        reloaded = parse_file(target)
+        assert any(section.auto_synth for section in reloaded.sections)
+        assert any("keep me" in section.content for section in reloaded.sections)
+
+    def test_existing_document_section_is_skipped(self, tmp_path: Path) -> None:
+        target = tmp_path / "doc.dlm"
+        existing = _auto_synth_instruction()
+        parsed = parse_text(_FRONTMATTER, path=target)
+        plan = build_apply_plan(parsed, [existing])
+        apply_plan(parsed, plan, target=target)
+
+        reloaded = parse_file(target)
+        second_plan = build_apply_plan(reloaded, [existing])
+        assert second_plan.additions == ()
+        assert len(second_plan.skipped) == 1
+        assert second_plan.skipped[0].reason is SynthApplySkipReason.ALREADY_PRESENT
+
+
+class TestRevertAutoSynth:
+    def test_strips_only_auto_synth_instructions(self, tmp_path: Path) -> None:
+        target = tmp_path / "doc.dlm"
+        _write_dlm(target, "## hello\n\nkeep me\n")
+        parsed = parse_file(target)
+        plan = build_apply_plan(parsed, [_auto_synth_instruction()])
+        apply_plan(parsed, plan, target=target)
+
+        reloaded = parse_file(target)
+        updated = ParsedDlm(
+            frontmatter=reloaded.frontmatter,
+            sections=reloaded.sections + (_authored_instruction(), _preference()),
+            source_path=reloaded.source_path,
+        )
+        target.write_text(serialize(updated), encoding="utf-8")
+
+        parsed_with_all = parse_file(target)
+        summary = revert_all_auto_synth(parsed_with_all, target=target)
+
+        assert summary.added == 0
+        assert len(summary.added_section_ids) == 1
+
+        final = parse_file(target)
+        assert not any(section.auto_synth for section in final.sections)
+        assert any(section.type is SectionType.PREFERENCE for section in final.sections)
+        assert any(
+            section.type is SectionType.INSTRUCTION and not section.auto_synth
+            for section in final.sections
+        )
+        assert any("keep me" in section.content for section in final.sections)
+
+    def test_revert_noop_when_no_auto_synth(self, tmp_path: Path) -> None:
+        target = tmp_path / "doc.dlm"
+        _write_dlm(target, "::instruction::\n### Q\nmanual?\n### A\nyes\n")
+        parsed = parse_file(target)
+        summary = revert_all_auto_synth(parsed, target=target)
+
+        assert summary.added == 0
+        assert summary.added_section_ids == ()
+        reloaded = parse_file(target)
+        assert len(reloaded.sections) == len(parsed.sections)
+
+
+class TestPendingPlan:
+    def test_pending_path_round_trip_and_clear(self, tmp_path: Path) -> None:
+        home = tmp_path / "home"
+        source_path = tmp_path / "doc.dlm"
+        _write_dlm(source_path)
+        store = for_dlm(_DLM_ID, home=home)
+
+        path = pending_plan_path(store)
+        assert path == home / "store" / _DLM_ID / "synth" / "pending.json"
+
+        saved = save_pending_plan(
+            store,
+            source_path=source_path,
+            sections=[_auto_synth_instruction(), _image()],
+        )
+        raw = json.loads(path.read_text(encoding="utf-8"))
+        loaded = load_pending_plan(store)
+
+        assert saved.source_path == source_path.resolve()
+        assert saved.created_at.endswith("Z")
+        assert raw["schema_version"] == 1
+        assert raw["source_path"] == str(source_path.resolve())
+        assert loaded == saved
+        assert clear_pending_plan(store) is True
+        assert clear_pending_plan(store) is False
+        assert load_pending_plan(store) is None
+
+    def test_load_returns_none_when_plan_absent(self, tmp_path: Path) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+
+        assert load_pending_plan(store) is None
+
+    def test_load_rejects_unreadable_plan(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("{}", encoding="utf-8")
+
+        def _raise(_self: Path, *, encoding: str) -> str:
+            _ = encoding
+            raise OSError("boom")
+
+        monkeypatch.setattr(Path, "read_text", _raise)
+        with pytest.raises(PendingSynthPlanError, match="could not read staged synth plan"):
+            load_pending_plan(store)
+
+    @pytest.mark.parametrize(
+        ("payload", "message"),
+        [
+            (["not", "an", "object"], "must be a JSON object"),
+            ({"schema_version": 2}, "unsupported staged synth plan schema_version=2"),
+            (
+                {"schema_version": 1, "created_at": "2026-04-24T20:00:00Z", "sections": []},
+                "missing source_path",
+            ),
+            (
+                {"schema_version": 1, "source_path": "/tmp/doc.dlm", "sections": []},
+                "missing created_at",
+            ),
+            (
+                {
+                    "schema_version": 1,
+                    "source_path": "/tmp/doc.dlm",
+                    "created_at": "2026-04-24T20:00:00Z",
+                },
+                "missing sections",
+            ),
+            (
+                {
+                    "schema_version": 1,
+                    "source_path": "/tmp/doc.dlm",
+                    "created_at": "2026-04-24T20:00:00Z",
+                    "sections": [{"content": "oops"}],
+                },
+                "invalid section payload at index 0",
+            ),
+        ],
+    )
+    def test_load_rejects_invalid_payloads(
+        self, tmp_path: Path, payload: object, message: str
+    ) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(json.dumps(payload), encoding="utf-8")
+
+        with pytest.raises(PendingSynthPlanError, match=message):
+            load_pending_plan(store)
+
+    def test_load_rejects_invalid_json(self, tmp_path: Path) -> None:
+        store = for_dlm(_DLM_ID, home=tmp_path / "home")
+        path = pending_plan_path(store)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text("{not json", encoding="utf-8")
+
+        with pytest.raises(PendingSynthPlanError, match="staged synth plan is not valid JSON"):
+            load_pending_plan(store)
+
+
+class TestPendingPayloadHelpers:
+    def test_section_from_payload_validates_tags_and_optional_types(self) -> None:
+        with pytest.raises(TypeError, match="expected object, got list"):
+            _section_from_payload([])
+
+        with pytest.raises(TypeError, match="tags must be an object"):
+            _section_from_payload({"type": "instruction", "content": "x", "tags": []})
+
+        with pytest.raises(TypeError, match="tags keys and values must be strings"):
+            _section_from_payload({"type": "instruction", "content": "x", "tags": {"topic": 1}})
+
+        with pytest.raises(TypeError, match="expected float or null"):
+            _section_from_payload(
+                {"type": "instruction", "content": "x", "judge_score_chosen": True}
+            )
+
+        with pytest.raises(TypeError, match="expected int or null"):
+            _section_from_payload({"type": "instruction", "content": "x", "mined_run_id": True})
+
+    def test_optional_helpers_accept_none_and_reject_wrong_types(self) -> None:
+        assert _optional_str(None) is None
+        assert _optional_str("ok") == "ok"
+        assert _optional_float(None) is None
+        assert _optional_float(1) == 1.0
+        assert _optional_int(None) is None
+        assert _optional_int(7) == 7
+
+        with pytest.raises(TypeError, match="expected string or null"):
+            _optional_str(7)
+
+        with pytest.raises(TypeError, match="expected float or null"):
+            _optional_float(True)
+
+        with pytest.raises(TypeError, match="expected int or null"):
+            _optional_int(True)
tests/unit/synth/test_filter.py (modified)
63 lines changed
@@ -53,6 +53,12 @@ class StubJudge:
 
 
 class TestFilterSynthPlan:
+    def test_negative_threshold_is_rejected(self) -> None:
+        raw = SynthRunPlan(additions=(_planned(),), skipped=())
+
+        with pytest.raises(ValueError, match="threshold must be >= 0.0"):
+            filter_synth_plan(raw, filter_kind="sway", judge=StubJudge({}), threshold=-0.1)
+
     def test_none_filter_keeps_deduped_additions(self) -> None:
         raw = SynthRunPlan(
             additions=(
@@ -95,6 +101,28 @@
         assert filtered.report.dedup_count == 1
         assert filtered.report.accepted_count == 1
 
+    def test_dedup_only_removes_near_duplicates_by_similarity(self) -> None:
+        raw = SynthRunPlan(
+            additions=(
+                _planned(
+                    question="What does DGEMM compute?",
+                    answer="A dense matrix product.",
+                ),
+                _planned(
+                    source_section_id="bbbbbbbbbbbbbbbb",
+                    question="What does DGEMM compute",
+                    answer="A dense matrix product.",
+                ),
+            ),
+            skipped=(),
+        )
+
+        filtered = filter_synth_plan(raw, filter_kind="dedup-only")
+
+        assert len(filtered.additions) == 1
+        assert len(filtered.filtered_skipped) == 1
+        assert filtered.filtered_skipped[0].reason.value == "duplicate_pair"
+
     def test_sway_filter_uses_judge_and_threshold(self) -> None:
         first = _planned(question="Q1", answer="A1")
         second = _planned(source_section_id="bbbbbbbbbbbbbbbb", question="Q2", answer="A2")
@@ -152,3 +180,23 @@ class TestFilterSynthPlan:
         rendered = render_filter_report(filtered)
 
         assert "generated 1, dedup 1, judge passed 1, threshold 1" in rendered
+
+    def test_render_filter_report_for_dedup_only_mentions_filtered_entries(self) -> None:
+        raw = SynthRunPlan(
+            additions=(
+                _planned(question="What is DGEMM?", answer="A matrix multiply routine."),
+                _planned(
+                    source_section_id="bbbbbbbbbbbbbbbb",
+                    question="What is DGEMM?",
+                    answer="A matrix multiply routine!",
+                ),
+            ),
+            skipped=(),
+        )
+
+        filtered = filter_synth_plan(raw, filter_kind="dedup-only")
+        rendered = render_filter_report(filtered)
+
+        assert "generated 2, dedup 1, accepted 1" in rendered
+        assert "=== filtered ===" in rendered
+        assert "duplicate_pair" in rendered
tests/unit/synth/test_prompts.py (modified)
42 lines changed
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from typing import Literal, cast
+
 import pytest
 
 from dlm.synth import DEFAULT_PROMPT_TEMPLATES, SynthPromptTemplate, get_prompt_template
@@ -13,8 +15,9 @@ def test_shipped_prompt_templates_cover_both_strategies() -> None:
 
 @pytest.mark.parametrize("strategy", ["extraction", "expansion"])
 def test_get_prompt_template_returns_shipped_template(strategy: str) -> None:
-    template = get_prompt_template(strategy)  # type: ignore[arg-type]
-    assert template is DEFAULT_PROMPT_TEMPLATES[strategy]
+    typed_strategy = cast(Literal["extraction", "expansion"], strategy)
+    template = get_prompt_template(typed_strategy)
+    assert template is DEFAULT_PROMPT_TEMPLATES[typed_strategy]
     assert template.output_parser == "json_list"
 
 
@@ -25,9 +28,20 @@ def test_render_user_prompt_injects_required_values() -> None:
    assert "3" in rendered
 
 
-def test_user_template_must_reference_required_variables() -> None:
-    with pytest.raises(ValueError, match="required variable"):
+@pytest.mark.parametrize(
+    ("template", "missing"),
+    [
+        ("Missing one variable: {{ prose }}", "['n']"),
+        ("Missing one variable: {{ n }}", "['prose']"),
+        ("Missing both variables.", "['prose', 'n']"),
+    ],
+)
+def test_user_template_must_reference_required_variables(
+    template: str,
+    missing: str,
+) -> None:
+    with pytest.raises(ValueError, match=missing):
         SynthPromptTemplate(
             system_prompt="hi",
-            user_template="Missing one variable: {{ prose }}",
+            user_template=template,
         )
tests/unit/synth/test_run_dry_run.py (modified)
145 lines changed
@@ -7,7 +7,7 @@ from collections import deque
 import pytest
 
 import dlm.synth.run as run_mod
-from dlm.doc.parser import parse_text
+from dlm.doc.parser import ParsedDlm, parse_text
 from dlm.synth import SynthPromptTemplate, build_synth_plan, render_synth_plan
 
 _FRONTMATTER = """---
@@ -49,11 +49,37 @@ class StubTeacher:
         return self._outputs.popleft()
 
 
-def _parsed(body: str):
+def _parsed(body: str) -> ParsedDlm:
     return parse_text(_FRONTMATTER + body)
 
 
 class TestBuildSynthPlan:
+    @pytest.mark.parametrize(
+        ("field", "value", "message"),
+        [
+            ("per_section", 0, "per_section must be >= 1"),
+            ("max_pairs", 0, "max_pairs must be >= 1"),
+            ("max_new_tokens", 0, "max_new_tokens must be >= 1"),
+        ],
+    )
+    def test_rejects_invalid_limits(
+        self,
+        field: str,
+        value: int,
+        message: str,
+    ) -> None:
+        parsed = _parsed("One prose block.\n")
+
+        if field == "per_section":
+            with pytest.raises(ValueError, match=message):
+                build_synth_plan(parsed, StubTeacher([]), per_section=value)
+        elif field == "max_pairs":
+            with pytest.raises(ValueError, match=message):
+                build_synth_plan(parsed, StubTeacher([]), max_pairs=value)
+        else:
+            with pytest.raises(ValueError, match=message):
+                build_synth_plan(parsed, StubTeacher([]), max_new_tokens=value)
+
     def test_materializes_auto_synth_instruction_sections(self) -> None:
         parsed = _parsed("A short prose section about matrix multiplication.\n")
         teacher = StubTeacher(
@@ -192,6 +218,47 @@ class TestBuildSynthPlan:
 
         assert len(plan.additions) == 1
         assert plan.additions[0].pair.question == "Q1"
+        assert len(teacher.calls) == 1
+
+    def test_max_pairs_returns_before_generating_from_later_sources(self) -> None:
+        parsed = _parsed(
+            "First prose block.\n\n"
+            "::instruction::\n"
+            "### Q\nmanual?\n"
+            "### A\nyes.\n\n"
+            "Second prose block.\n"
+        )
+        teacher = StubTeacher(
+            [
+                '[{"question":"Q1","answer":"A1"}]',
+                '[{"question":"Q2","answer":"A2"}]',
+            ]
+        )
+
+        plan = build_synth_plan(parsed, teacher, per_section=1, strategy="extraction", max_pairs=1)
+
+        assert len(plan.additions) == 1
+        assert len(teacher.calls) == 1
+
+    def test_both_strategy_skips_zero_count_branch(self) -> None:
+        parsed = _parsed("One prose block.\n")
+        teacher = StubTeacher(['[{"question":"Q1","answer":"A1"}]'])
+
+        plan = build_synth_plan(parsed, teacher, per_section=1, strategy="both")
+
+        assert len(plan.additions) == 1
+        assert [add.strategy for add in plan.additions] == ["extraction"]
+        assert len(teacher.calls) == 1
+
+    def test_expansion_strategy_uses_expansion_template(self) -> None:
+        parsed = _parsed("One prose block.\n")
+        teacher = StubTeacher(['[{"question":"Q1","answer":"A1"}]'])
+
+        plan = build_synth_plan(parsed, teacher, per_section=1, strategy="expansion")
+
+        assert len(plan.additions) == 1
+        assert [add.strategy for add in plan.additions] == ["expansion"]
+        assert "expand on the material" in teacher.calls[0][1]
195262
 
196263
 
197264
 def test_render_synth_plan_mentions_adds_and_skips() -> None:
@@ -203,3 +270,52 @@ def test_render_synth_plan_mentions_adds_and_skips() -> None:
 
     assert "synth plan: 0 add, 1 skip" in rendered
     assert "invalid_output" in rendered
+
+
+def test_render_synth_plan_mentions_additions_and_truncates_long_lines() -> None:
+    parsed = _parsed("One prose block.\n")
+    long_question = "Q" * 90
+    long_answer = "A" * 90
+    teacher = StubTeacher([f'[{{"question":"{long_question}","answer":"{long_answer}"}}]'])
+
+    plan = build_synth_plan(parsed, teacher, per_section=1, strategy="extraction")
+    rendered = render_synth_plan(plan)
+
+    assert "synth plan: 1 add, 0 skip" in rendered
+    assert "+ ::instruction::" in rendered
+    assert "q: " in rendered
+    assert "a: " in rendered
+    assert "…" in rendered
+
+
+def test_first_line_returns_short_text_unchanged() -> None:
+    assert run_mod._first_line("short line") == "short line"
+
+
+@pytest.mark.parametrize(
+    ("raw", "message"),
+    [
+        ("[]", "teacher output produced no instruction pairs"),
+        ("{}", "teacher output must be a JSON list"),
+        ("[1]", "teacher output item 0 must be an object"),
+        ('[{"question":1,"answer":"ok"}]', "must contain string question/answer keys"),
+        ('[{"question":" ","answer":"ok"}]', "has an empty question or answer"),
+    ],
+)
+def test_parse_generated_pairs_rejects_bad_json_list_payloads(raw: str, message: str) -> None:
+    with pytest.raises(ValueError, match=message):
+        run_mod._parse_generated_pairs(raw, parser="json_list")
+
+
+@pytest.mark.parametrize(
+    ("raw", "message"),
+    [
+        ("Question: hi\nA: ok", "must use lines like `1. Q: ...`"),
+        ("1. Q: hi", "missing an answer line"),
+        ("1. Q: hi\nB: ok", "answers must use `A:` or `Answer:`"),
+        ("1. Q:   \nA: ok", "contains an empty question or answer"),
+    ],
+)
+def test_parse_generated_pairs_rejects_bad_numbered_list_payloads(raw: str, message: str) -> None:
+    with pytest.raises(ValueError, match=message):
+        run_mod._parse_generated_pairs(raw, parser="numbered_list")
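The `numbered_list` cases above pin down a line format of `1. Q: ...` (or `Question:`) followed by `A: ...` (or `Answer:`). A minimal sketch of that contract, where `parse_numbered` is a hypothetical stand-in for illustration, not the actual `run_mod._parse_generated_pairs` implementation:

```python
import re


def parse_numbered(raw: str) -> list[tuple[str, str]]:
    # Hypothetical stand-in illustrating the contract the parametrized
    # tests assert; not the real run_mod._parse_generated_pairs.
    pairs: list[tuple[str, str]] = []
    question: str | None = None
    for line in (ln.strip() for ln in raw.splitlines() if ln.strip()):
        q = re.match(r"\d+\.\s*(?:Q|Question):\s*(.*)", line)
        if q:
            if question is not None:
                raise ValueError("missing an answer line")
            question = q.group(1).strip()
            continue
        a = re.match(r"(?:A|Answer):\s*(.*)", line)
        if a is None:
            if question is None:
                raise ValueError("questions must use lines like `1. Q: ...`")
            raise ValueError("answers must use `A:` or `Answer:`")
        if question is None:
            raise ValueError("questions must use lines like `1. Q: ...`")
        answer = a.group(1).strip()
        if not question or not answer:
            raise ValueError("contains an empty question or answer")
        pairs.append((question, answer))
        question = None
    if question is not None:
        raise ValueError("missing an answer line")
    return pairs
```

Each rejection test above maps onto one of the `ValueError` branches in this sketch.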
tests/unit/synth/test_teachers.py modified — 1001 lines changed
@@ -2,9 +2,13 @@
 
 from __future__ import annotations
 
+import builtins
+import json
+import sys
+import urllib.error
 from pathlib import Path
-from types import SimpleNamespace
-from typing import Any
+from types import ModuleType, SimpleNamespace
+from typing import Any, Literal
 
 import pytest
 
@@ -15,6 +19,7 @@ from dlm.synth import (
     InvalidTeacherSpecError,
     OpenAiTeacher,
     SelfTeacher,
+    TeacherInvocationError,
     TeacherUnavailableError,
     VllmServerTeacher,
     build_teacher,
@@ -22,6 +27,13 @@ from dlm.synth import (
 )
 
 
+def _module(name: str, **attrs: object) -> ModuleType:
+    module = ModuleType(name)
+    for key, value in attrs.items():
+        setattr(module, key, value)
+    return module
+
+
 class TestTeacherSelectorParsing:
     @pytest.mark.parametrize(
         ("raw", "kind", "target"),
@@ -46,6 +58,19 @@ class TestTeacherSelectorParsing:
         with pytest.raises(InvalidTeacherSpecError, match="unknown teacher selector"):
             parse_teacher_ref("mystery:thing")
 
+    @pytest.mark.parametrize(
+        ("raw", "message"),
+        [
+            ("hf:   ", "hf teacher selector must include a model id"),
+            ("openai:   ", "openai teacher selector must include a model id"),
+            ("anthropic:   ", "anthropic teacher selector must include a model id"),
+            ("vllm-server:   ", "vllm-server teacher selector must include a URL"),
+        ],
+    )
+    def test_missing_selector_targets_are_refused(self, raw: str, message: str) -> None:
+        with pytest.raises(InvalidTeacherSpecError, match=message):
+            parse_teacher_ref(raw)
+
 
 class TestBuildTeacher:
     def test_self_requires_dlm_path(self) -> None:
@@ -103,6 +128,10 @@ class TestSelfTeacher:
 
 
 class TestHfTeacher:
+    def test_blank_hf_id_refused(self) -> None:
+        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
+            HfTeacher("   ")
+
     def test_hf_teacher_uses_loader_and_runner(self) -> None:
        seen: dict[str, Any] = {}
 
@@ -134,8 +163,31 @@ class TestHfTeacher:
         )
         assert seen["runner"][3:] == (21, 0.5, 0.8, 11)
 
+    def test_hf_teacher_reuses_loaded_bundle(self) -> None:
+        loads: list[tuple[str, str]] = []
+
+        def _loader(hf_id: str, device: str) -> teachers_mod._LoadedHfTeacher:
+            loads.append((hf_id, device))
+            return teachers_mod._LoadedHfTeacher(model="model", tokenizer="tok", device=device)
+
+        teacher = HfTeacher(
+            "Qwen/Qwen2.5-1.5B-Instruct",
+            loader=_loader,
+            runner=lambda *_args, **_kwargs: "ok",
+        )
+
+        assert teacher.generate("system", "user") == "ok"
+        assert teacher.generate("system", "user") == "ok"
+        assert loads == [
+            ("Qwen/Qwen2.5-1.5B-Instruct", teachers_mod._resolve_generation_device("auto"))
+        ]
+
 
 class TestOpenAiTeacher:
+    def test_blank_model_refused(self) -> None:
+        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
+            OpenAiTeacher("   ")
+
     def test_missing_api_key_refused(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
         teacher = OpenAiTeacher("gpt-4o-mini")
@@ -145,28 +197,59 @@
     def test_openai_teacher_extracts_message_text(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv("OPENAI_API_KEY", "secret")
 
-        captured: dict[str, Any] = {}
+        payloads: list[dict[str, Any]] = []
+        factories: list[str] = []
 
         def _create(**kwargs: Any) -> Any:
-            captured["payload"] = kwargs
+            payloads.append(kwargs)
             return SimpleNamespace(
                 choices=[SimpleNamespace(message=SimpleNamespace(content=" generated "))]
             )
 
+        def _factory(api_key: str) -> Any:
+            factories.append(api_key)
+            return client
+
         client = SimpleNamespace(
             chat=SimpleNamespace(
                 completions=SimpleNamespace(create=_create),
             )
         )
 
-        teacher = OpenAiTeacher("gpt-4o-mini", client_factory=lambda api_key: client)
+        teacher = OpenAiTeacher(
+            "gpt-4o-mini",
+            client_factory=_factory,
+        )
         out = teacher.generate("sys", "usr", max_new_tokens=17, temperature=0.3, top_p=0.7, seed=5)
+        second = teacher.generate("sys", "usr")
         assert out == "generated"
-        assert captured["payload"]["model"] == "gpt-4o-mini"
-        assert captured["payload"]["seed"] == 5
+        assert second == "generated"
+        assert payloads[0]["model"] == "gpt-4o-mini"
+        assert payloads[0]["seed"] == 5
+        assert factories == ["secret"]
+
+    def test_openai_teacher_wraps_request_failures(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("OPENAI_API_KEY", "secret")
+
+        def _create(**_kwargs: Any) -> Any:
+            raise RuntimeError("boom")
+
+        client = SimpleNamespace(
+            chat=SimpleNamespace(
+                completions=SimpleNamespace(create=_create),
+            )
+        )
+        teacher = OpenAiTeacher("gpt-4o-mini", client_factory=lambda _api_key: client)
+
+        with pytest.raises(TeacherInvocationError, match="openai:gpt-4o-mini request failed: boom"):
+            teacher.generate("sys", "usr")
 
 
 class TestAnthropicTeacher:
+    def test_blank_model_refused(self) -> None:
+        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
+            AnthropicTeacher("   ")
+
     def test_missing_api_key_refused(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
         teacher = AnthropicTeacher("claude-3-5-haiku-latest")
@@ -176,6 +259,7 @@
     def test_anthropic_teacher_extracts_text_blocks(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv("ANTHROPIC_API_KEY", "secret")
         captured: dict[str, Any] = {}
+        factories: list[str] = []
 
         class _Messages:
             @staticmethod
@@ -183,6 +267,7 @@
                 captured["payload"] = kwargs
                 return SimpleNamespace(
                     content=[
+                        SimpleNamespace(type="image", text="ignored"),
                         SimpleNamespace(type="text", text=" first "),
                         SimpleNamespace(type="text", text=" second "),
                     ]
@@ -191,16 +276,51 @@
         class _Client:
             messages = _Messages()
 
+        def _factory(api_key: str) -> _Client:
+            factories.append(api_key)
+            return _Client()
+
         teacher = AnthropicTeacher(
             "claude-3-5-haiku-latest",
-            client_factory=lambda api_key: _Client(),
+            client_factory=_factory,
         )
         out = teacher.generate("sys", "usr", max_new_tokens=19, temperature=0.2, top_p=0.6)
+        second = teacher.generate("sys", "usr")
         assert out == "first\nsecond"
+        assert second == "first\nsecond"
         assert captured["payload"]["model"] == "claude-3-5-haiku-latest"
+        assert factories == ["secret"]
+
+    def test_anthropic_teacher_wraps_request_failures(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setenv("ANTHROPIC_API_KEY", "secret")
+
+        class _Messages:
+            @staticmethod
+            def create(**_kwargs: Any) -> Any:
+                raise RuntimeError("boom")
+
+        class _Client:
+            messages = _Messages()
+
+        teacher = AnthropicTeacher(
+            "claude-3-5-haiku-latest",
+            client_factory=lambda _api_key: _Client(),
+        )
+
+        with pytest.raises(
+            TeacherInvocationError,
+            match="anthropic:claude-3-5-haiku-latest request failed: boom",
+        ):
+            teacher.generate("sys", "usr")
 
 
 class TestVllmServerTeacher:
+    def test_blank_url_refused(self) -> None:
+        with pytest.raises(InvalidTeacherSpecError, match="must include a URL"):
+            VllmServerTeacher("   ")
+
     def test_invalid_url_refused(self) -> None:
         with pytest.raises(InvalidTeacherSpecError, match="http\\(s\\)"):
             VllmServerTeacher("localhost:8000")
@@ -208,10 +328,11 @@
     def test_vllm_teacher_queries_model_and_completion(
         self, monkeypatch: pytest.MonkeyPatch
     ) -> None:
-        calls: dict[str, Any] = {}
+        model_calls: list[tuple[str, float]] = []
+        completion_calls: list[tuple[Any, ...]] = []
 
         def _fake_models(base_url: str, *, request_timeout: float) -> str | None:
-            calls["models"] = (base_url, request_timeout)
+            model_calls.append((base_url, request_timeout))
             return "demo-model"
 
         def _fake_completion(
@@ -225,15 +346,17 @@
             seed: int | None,
             request_timeout: float,
         ) -> str:
-            calls["completion"] = (
-                base_url,
-                model_id,
-                messages,
-                max_new_tokens,
-                temperature,
-                top_p,
-                seed,
-                request_timeout,
+            completion_calls.append(
+                (
+                    base_url,
+                    model_id,
+                    messages,
+                    max_new_tokens,
+                    temperature,
+                    top_p,
+                    seed,
+                    request_timeout,
+                )
             )
             return " served "
 
@@ -242,8 +365,734 @@
         teacher = VllmServerTeacher("http://127.0.0.1:8000")
         out = teacher.generate("sys", "usr", max_new_tokens=29, temperature=0.4, top_p=0.75, seed=9)
+        second = teacher.generate("sys", "usr")
 
         assert out == "served"
-        assert calls["models"] == ("http://127.0.0.1:8000", 30.0)
-        assert calls["completion"][1] == "demo-model"
-        assert calls["completion"][3:] == (29, 0.4, 0.75, 9, 30.0)
+        assert second == "served"
+        assert model_calls == [("http://127.0.0.1:8000", 30.0)]
+        assert completion_calls[0][1] == "demo-model"
+        assert completion_calls[0][3:] == (29, 0.4, 0.75, 9, 30.0)
+
+
+
377
+class TestTeacherHelpers:
+    def test_flatten_teacher_prompt_handles_partial_inputs(self) -> None:
+        assert teachers_mod._flatten_teacher_prompt("system", "user").startswith("System:\n")
+        assert teachers_mod._flatten_teacher_prompt("", "user") == "user"
+        assert teachers_mod._flatten_teacher_prompt("system", "") == "system"
+
+    def test_require_non_empty_teacher_output_refuses_blank_text(self) -> None:
+        with pytest.raises(TeacherInvocationError, match="self returned empty output"):
+            teachers_mod._require_non_empty_teacher_output("   ", teacher="self")
+
+    def test_extract_openai_message_text_handles_list_content_and_errors(self) -> None:
+        response = {
+            "choices": [
+                {
+                    "message": {
+                        "content": [
+                            {"text": " first "},
+                            {"text": " second "},
+                        ]
+                    }
+                }
+            ]
+        }
+        assert teachers_mod._extract_openai_message_text(response) == "first\nsecond"
+
+        with pytest.raises(TeacherInvocationError, match="missing choices"):
+            teachers_mod._extract_openai_message_text({})
+
+        with pytest.raises(TeacherInvocationError, match="missing choices\\[0\\]\\.message"):
+            teachers_mod._extract_openai_message_text({"choices": [{}]})
+
+        with pytest.raises(TeacherInvocationError, match="missing non-empty message content"):
+            teachers_mod._extract_openai_message_text({"choices": [{"message": {"content": None}}]})
+
+    def test_extract_anthropic_text_handles_errors(self) -> None:
+        with pytest.raises(TeacherInvocationError, match="missing content blocks"):
+            teachers_mod._extract_anthropic_text({})
+
+        with pytest.raises(TeacherInvocationError, match="missing non-empty text blocks"):
+            teachers_mod._extract_anthropic_text(
+                {"content": [{"type": "image", "text": "ignored"}, {"type": "text", "text": "   "}]}
+            )
+
+    def test_normalize_chat_content_and_obj_get_helpers(self) -> None:
+        assert teachers_mod._normalize_chat_content(" hello ") == "hello"
+        assert (
+            teachers_mod._normalize_chat_content([{"text": " one "}, {"text": " two "}])
+            == "one\ntwo"
+        )
+        assert teachers_mod._normalize_chat_content([{"text": "   "}]) is None
+        assert teachers_mod._normalize_chat_content(123) is None
+        assert teachers_mod._obj_get({"name": "value"}, "name") == "value"
+        assert teachers_mod._obj_get(SimpleNamespace(name="value"), "name") == "value"
+
+    def test_openai_compat_url_helpers_normalize_suffixes(self) -> None:
+        assert (
+            teachers_mod._normalize_openai_compat_base_url(
+                "http://127.0.0.1:8000/v1/chat/completions"
+            )
+            == "http://127.0.0.1:8000"
+        )
+        assert (
+            teachers_mod._normalize_openai_compat_base_url("http://127.0.0.1:8000/chat/completions")
+            == "http://127.0.0.1:8000"
+        )
+        assert teachers_mod._openai_compat_models_url("http://127.0.0.1:8000/v1") == (
+            "http://127.0.0.1:8000/v1/models"
+        )
+        assert teachers_mod._openai_compat_models_url("http://127.0.0.1:8000") == (
+            "http://127.0.0.1:8000/v1/models"
+        )
+        assert teachers_mod._openai_compat_chat_url("http://127.0.0.1:8000/v1") == (
+            "http://127.0.0.1:8000/v1/chat/completions"
+        )
+        assert teachers_mod._openai_compat_chat_url("http://127.0.0.1:8000") == (
+            "http://127.0.0.1:8000/v1/chat/completions"
+        )
+
+
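The URL-helper cases above pin down a small suffix contract: a pasted chat endpoint is stripped back to the server root, and `/v1` is appended exactly once before `/models`. A minimal sketch of that behaviour, where `models_url` is a hypothetical illustration, not the `dlm.synth.teachers` helper itself:

```python
def models_url(base: str) -> str:
    # Hypothetical illustration of the suffix rules the tests assert;
    # not the dlm.synth.teachers implementation.
    base = base.rstrip("/")
    # A pasted chat endpoint is normalized back to the server root.
    for suffix in ("/v1/chat/completions", "/chat/completions"):
        if base.endswith(suffix):
            base = base[: -len(suffix)]
            break
    # The /v1 segment is added exactly once.
    if not base.endswith("/v1"):
        base += "/v1"
    return base + "/models"
```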
456
+class TestTeacherRuntimeHelpers:
+    def test_resolve_generation_device_prefers_requested_or_detected_backends(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        assert teachers_mod._resolve_generation_device("mps") == "mps"
+
+        monkeypatch.delitem(sys.modules, "torch", raising=False)
+        real_import = builtins.__import__
+
+        def _missing_torch(
+            name: str,
+            globals: dict[str, object] | None = None,
+            locals: dict[str, object] | None = None,
+            fromlist: tuple[str, ...] = (),
+            level: int = 0,
+        ) -> object:
+            if name == "torch":
+                raise ImportError("no torch")
+            return real_import(name, globals, locals, fromlist, level)
+
+        monkeypatch.setattr(builtins, "__import__", _missing_torch)
+        assert teachers_mod._resolve_generation_device("auto") == "cpu"
+
+        monkeypatch.setattr(builtins, "__import__", real_import)
+        monkeypatch.setitem(
+            sys.modules,
+            "torch",
+            SimpleNamespace(
+                cuda=SimpleNamespace(is_available=lambda: True),
+                backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
+            ),
+        )
+        assert teachers_mod._resolve_generation_device("auto") == "cuda"
+
+        monkeypatch.setitem(
+            sys.modules,
+            "torch",
+            SimpleNamespace(
+                cuda=SimpleNamespace(is_available=lambda: False),
+                backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True)),
+            ),
+        )
+        assert teachers_mod._resolve_generation_device("auto") == "mps"
+
+        monkeypatch.setitem(
+            sys.modules,
+            "torch",
+            SimpleNamespace(
+                cuda=SimpleNamespace(is_available=lambda: False),
+                backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
+            ),
+        )
+        assert teachers_mod._resolve_generation_device("auto") == "cpu"
+
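The monkeypatched `torch` stand-ins above walk a fallback order: an explicit request wins, then cuda, then mps, then cpu. A minimal sketch of that order, assuming `resolve_device` as a hypothetical stand-in for `teachers_mod._resolve_generation_device`:

```python
def resolve_device(requested: str = "auto") -> str:
    # Hypothetical stand-in for the resolution order the tests exercise;
    # not the dlm.synth.teachers implementation.
    if requested != "auto":
        return requested  # an explicit request is honored as-is
    try:
        import torch  # optional dependency; absence falls back to cpu
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```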
511
+    def test_default_openai_client_validates_import_surface(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        def _raise_import(name: str) -> object:
+            raise ImportError(name)
+
+        monkeypatch.setattr("dlm.synth.teachers.importlib.import_module", _raise_import)
+        with pytest.raises(TeacherUnavailableError, match="requires the openai package"):
+            teachers_mod._default_openai_client("secret")
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.importlib.import_module", lambda _name: SimpleNamespace()
+        )
+        with pytest.raises(TeacherUnavailableError, match="does not expose OpenAI client"):
+            teachers_mod._default_openai_client("secret")
+
+        captured: list[str] = []
+
+        class _OpenAI:
+            def __init__(self, *, api_key: str) -> None:
+                captured.append(api_key)
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.importlib.import_module",
+            lambda _name: SimpleNamespace(OpenAI=_OpenAI),
+        )
+        client = teachers_mod._default_openai_client("secret")
+        assert isinstance(client, _OpenAI)
+        assert captured == ["secret"]
+
+    def test_default_anthropic_client_validates_import_surface(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        def _raise_import(name: str) -> object:
+            raise ImportError(name)
+
+        monkeypatch.setattr("dlm.synth.teachers.importlib.import_module", _raise_import)
+        with pytest.raises(TeacherUnavailableError, match="requires the anthropic package"):
+            teachers_mod._default_anthropic_client("secret")
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.importlib.import_module", lambda _name: SimpleNamespace()
+        )
+        with pytest.raises(TeacherUnavailableError, match="does not expose Anthropic client"):
+            teachers_mod._default_anthropic_client("secret")
+
+        captured: list[str] = []
+
+        class _Anthropic:
+            def __init__(self, *, api_key: str) -> None:
+                captured.append(api_key)
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.importlib.import_module",
+            lambda _name: SimpleNamespace(Anthropic=_Anthropic),
+        )
+        client = teachers_mod._default_anthropic_client("secret")
+        assert isinstance(client, _Anthropic)
+        assert captured == ["secret"]
+
573
+    def test_fetch_openai_compat_model_id_handles_success_empty_and_errors(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        class _Response:
+            def __init__(self, payload: object) -> None:
+                self._payload = payload
+
+            def __enter__(self) -> _Response:
+                return self
+
+            def __exit__(self, *_args: object) -> Literal[False]:
+                return False
+
+            def read(self) -> bytes:
+                return json.dumps(self._payload).encode("utf-8")
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response({"data": [{"id": "demo-model"}]}),
+        )
+        assert (
+            teachers_mod._fetch_openai_compat_model_id(
+                "http://127.0.0.1:8000",
+                request_timeout=1.0,
+            )
+            == "demo-model"
+        )
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response({"data": []}),
+        )
+        assert (
+            teachers_mod._fetch_openai_compat_model_id(
+                "http://127.0.0.1:8000",
+                request_timeout=1.0,
+            )
+            is None
+        )
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response({"data": [{"id": "   "}]}),
+        )
+        assert (
+            teachers_mod._fetch_openai_compat_model_id(
+                "http://127.0.0.1:8000",
+                request_timeout=1.0,
+            )
+            is None
+        )
+
+        def _raise_url_error(*_args: object, **_kwargs: object) -> object:
+            raise urllib.error.URLError("boom")
+
+        monkeypatch.setattr("dlm.synth.teachers.urllib.request.urlopen", _raise_url_error)
+        with pytest.raises(TeacherUnavailableError, match="could not query models"):
+            teachers_mod._fetch_openai_compat_model_id(
+                "http://127.0.0.1:8000",
+                request_timeout=1.0,
+            )
+
636
+    def test_request_openai_compat_completion_handles_success_and_failures(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        class _Response:
+            def __init__(self, payload: object) -> None:
+                self._payload = payload
+
+            def __enter__(self) -> _Response:
+                return self
+
+            def __exit__(self, *_args: object) -> Literal[False]:
+                return False
+
+            def read(self) -> bytes:
+                return json.dumps(self._payload).encode("utf-8")
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response(
+                {"choices": [{"message": {"content": [{"text": " served "}]}}]}
+            ),
+        )
+        assert (
+            teachers_mod._request_openai_compat_completion(
+                "http://127.0.0.1:8000",
+                model_id="demo-model",
+                messages=[{"role": "user", "content": "hello"}],
+                max_new_tokens=11,
+                temperature=0.2,
+                top_p=0.8,
+                seed=5,
+                request_timeout=1.0,
+            )
+            == "served"
+        )
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response({"choices": []}),
+        )
+        with pytest.raises(TeacherInvocationError, match="response missing choices"):
+            teachers_mod._request_openai_compat_completion(
+                "http://127.0.0.1:8000",
+                model_id=None,
+                messages=[{"role": "user", "content": "hello"}],
+                max_new_tokens=11,
+                temperature=0.2,
+                top_p=None,
+                seed=None,
+                request_timeout=1.0,
+            )
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response({"choices": [{}]}),
+        )
+        with pytest.raises(
+            TeacherInvocationError, match="response missing choices\\[0\\]\\.message"
+        ):
+            teachers_mod._request_openai_compat_completion(
+                "http://127.0.0.1:8000",
+                model_id=None,
+                messages=[{"role": "user", "content": "hello"}],
+                max_new_tokens=11,
+                temperature=0.2,
+                top_p=None,
+                seed=None,
+                request_timeout=1.0,
+            )
+
+        monkeypatch.setattr(
+            "dlm.synth.teachers.urllib.request.urlopen",
+            lambda *_args, **_kwargs: _Response(
+                {"choices": [{"message": {"content": [{"text": "   "}]}}]}
+            ),
+        )
+        with pytest.raises(TeacherInvocationError, match="missing non-empty message content"):
+            teachers_mod._request_openai_compat_completion(
+                "http://127.0.0.1:8000",
+                model_id=None,
+                messages=[{"role": "user", "content": "hello"}],
+                max_new_tokens=11,
+                temperature=0.2,
+                top_p=None,
+                seed=None,
+                request_timeout=1.0,
+            )
+
+        def _raise_url_error(*_args: object, **_kwargs: object) -> object:
+            raise urllib.error.URLError("boom")
+
+        monkeypatch.setattr("dlm.synth.teachers.urllib.request.urlopen", _raise_url_error)
+        with pytest.raises(TeacherInvocationError, match="request to http://127.0.0.1:8000 failed"):
+            teachers_mod._request_openai_compat_completion(
+                "http://127.0.0.1:8000",
+                model_id=None,
+                messages=[{"role": "user", "content": "hello"}],
+                max_new_tokens=11,
+                temperature=0.2,
+                top_p=None,
+                seed=None,
+                request_timeout=1.0,
+            )
+
+
742
+def _install_self_loader_modules(
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    manifest_exists: bool = True,
+    license_acceptance: object | None = "accepted",
+    load_manifest_error: str | None = None,
+    resolve_error: str | None = None,
+    select_error: str | None = None,
+    backend_load_error: str | None = None,
+) -> dict[str, object]:
+    calls: dict[str, object] = {}
+    spec = object()
+    caps = object()
+    parsed = SimpleNamespace(
+        frontmatter=SimpleNamespace(
+            dlm_id="01KPQ9X1000000000000000000",
+            base_model="smollm2-135m",
+        )
+    )
+    manifest = SimpleNamespace(exists=lambda: manifest_exists)
+    store = SimpleNamespace(manifest=manifest)
+
+    class GatedModelError(Exception):
+        pass
+
+    class AdapterNotFoundError(Exception):
+        pass
+
+    class UnsupportedBackendError(Exception):
+        pass
+
+    class ManifestCorruptError(Exception):
+        pass
+
+    class _Backend:
+        def load(self, spec_arg: object, store_arg: object) -> None:
+            calls["load"] = (spec_arg, store_arg)
+            if backend_load_error is not None:
+                raise AdapterNotFoundError(backend_load_error)
+
+    backend = _Backend()
+
+    def _resolve(base_model: str, *, accept_license: bool) -> object:
+        calls["resolve"] = (base_model, accept_license)
+        if resolve_error is not None:
+            raise GatedModelError(resolve_error)
+        return spec
+
+    def _load_manifest(_path: object) -> object:
+        calls["load_manifest"] = True
+        if load_manifest_error is not None:
+            raise ManifestCorruptError(load_manifest_error)
+        return SimpleNamespace(license_acceptance=license_acceptance)
+
+    def _select_backend(backend_name: str, capabilities: object) -> str:
+        calls["select_backend"] = (backend_name, capabilities)
798
+        if select_error is not None:
799
+            raise UnsupportedBackendError(select_error)
800
+        return "stub-backend"
801
+
802
+    def _build_backend(name: str, capabilities: object) -> object:
803
+        calls["build_backend"] = (name, capabilities)
804
+        return backend
805
+
806
+    monkeypatch.setitem(
807
+        sys.modules, "dlm.base_models", _module("dlm.base_models", resolve=_resolve)
808
+    )
809
+    monkeypatch.setitem(
810
+        sys.modules,
811
+        "dlm.base_models.errors",
812
+        _module("dlm.base_models.errors", GatedModelError=GatedModelError),
813
+    )
814
+    monkeypatch.setitem(
815
+        sys.modules,
816
+        "dlm.doc.parser",
817
+        _module("dlm.doc.parser", parse_file=lambda _path: parsed),
818
+    )
819
+    monkeypatch.setitem(
820
+        sys.modules,
821
+        "dlm.hardware",
822
+        _module("dlm.hardware", doctor=lambda: SimpleNamespace(capabilities=caps)),
823
+    )
824
+    monkeypatch.setitem(
825
+        sys.modules,
826
+        "dlm.inference",
827
+        _module("dlm.inference", AdapterNotFoundError=AdapterNotFoundError),
828
+    )
829
+    monkeypatch.setitem(
830
+        sys.modules,
831
+        "dlm.inference.backends",
832
+        _module(
833
+            "dlm.inference.backends", build_backend=_build_backend, select_backend=_select_backend
834
+        ),
835
+    )
836
+    monkeypatch.setitem(
837
+        sys.modules,
838
+        "dlm.inference.backends.select",
839
+        _module("dlm.inference.backends.select", UnsupportedBackendError=UnsupportedBackendError),
840
+    )
841
+    monkeypatch.setitem(
842
+        sys.modules,
843
+        "dlm.store.errors",
844
+        _module("dlm.store.errors", ManifestCorruptError=ManifestCorruptError),
845
+    )
846
+    monkeypatch.setitem(
847
+        sys.modules,
848
+        "dlm.store.manifest",
849
+        _module("dlm.store.manifest", load_manifest=_load_manifest),
850
+    )
851
+    monkeypatch.setitem(
852
+        sys.modules,
853
+        "dlm.store.paths",
854
+        _module("dlm.store.paths", for_dlm=lambda _dlm_id: store),
855
+    )
856
+
857
+    calls["caps"] = caps
858
+    calls["store"] = store
859
+    calls["spec"] = spec
860
+    calls["errors"] = {
861
+        "gated": GatedModelError,
862
+        "adapter": AdapterNotFoundError,
863
+        "unsupported": UnsupportedBackendError,
864
+        "manifest": ManifestCorruptError,
865
+    }
866
+    return calls
867
+
868
+
869
+class TestTeacherLoaderHelpers:
870
+    def test_load_self_backend_wraps_import_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
871
+        real_import = builtins.__import__
872
+
873
+        def _raise_on_base_models(
874
+            name: str,
875
+            globals: dict[str, object] | None = None,
876
+            locals: dict[str, object] | None = None,
877
+            fromlist: tuple[str, ...] = (),
878
+            level: int = 0,
879
+        ) -> object:
880
+            if name.startswith("dlm.base_models"):
881
+                raise ImportError("boom")
882
+            return real_import(name, globals, locals, fromlist, level)
883
+
884
+        monkeypatch.setattr(builtins, "__import__", _raise_on_base_models)
885
+        with pytest.raises(TeacherUnavailableError, match="requires the local inference stack"):
886
+            teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
887
+
888
+    def test_load_self_backend_uses_recorded_license_acceptance(
889
+        self,
890
+        monkeypatch: pytest.MonkeyPatch,
891
+    ) -> None:
892
+        calls = _install_self_loader_modules(monkeypatch, license_acceptance="accepted")
893
+
894
+        backend = teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
895
+
896
+        assert backend is not None
897
+        assert calls["resolve"] == ("smollm2-135m", True)
898
+        assert calls["select_backend"] == ("auto", calls["caps"])
899
+        assert calls["build_backend"] == ("stub-backend", calls["caps"])
900
+        assert calls["load"] == (calls["spec"], calls["store"])
901
+
902
+    def test_load_self_backend_tolerates_manifest_read_failure(
903
+        self,
904
+        monkeypatch: pytest.MonkeyPatch,
905
+    ) -> None:
906
+        calls = _install_self_loader_modules(
907
+            monkeypatch,
908
+            load_manifest_error="bad manifest",
909
+        )
910
+
911
+        teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
912
+
913
+        assert calls["resolve"] == ("smollm2-135m", False)
914
+
915
+    def test_load_self_backend_wraps_gated_backend_and_adapter_failures(
916
+        self,
917
+        monkeypatch: pytest.MonkeyPatch,
918
+    ) -> None:
919
+        _install_self_loader_modules(monkeypatch, resolve_error="gated")
920
+        with pytest.raises(TeacherUnavailableError, match="cannot resolve gated base"):
921
+            teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
922
+
923
+        _install_self_loader_modules(monkeypatch, select_error="unsupported backend")
924
+        with pytest.raises(TeacherUnavailableError, match="unsupported backend"):
925
+            teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
926
+
927
+        _install_self_loader_modules(monkeypatch, backend_load_error="missing adapter")
928
+        with pytest.raises(TeacherUnavailableError, match="requires a trained adapter"):
929
+            teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")
930
+
931
+    def test_default_hf_loader_wraps_import_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
932
+        real_import = builtins.__import__
933
+
934
+        def _raise_transformers(
935
+            name: str,
936
+            globals: dict[str, object] | None = None,
937
+            locals: dict[str, object] | None = None,
938
+            fromlist: tuple[str, ...] = (),
939
+            level: int = 0,
940
+        ) -> object:
941
+            if name == "transformers":
942
+                raise ImportError("boom")
943
+            return real_import(name, globals, locals, fromlist, level)
944
+
945
+        monkeypatch.setattr(builtins, "__import__", _raise_transformers)
946
+        with pytest.raises(TeacherUnavailableError, match="requires transformers"):
947
+            teachers_mod._default_hf_loader("hf/model", "cpu")
948
+
949
+    def test_default_hf_loader_moves_model_and_sets_eval(
950
+        self,
951
+        monkeypatch: pytest.MonkeyPatch,
952
+    ) -> None:
953
+        seen: dict[str, object] = {}
954
+
955
+        class _Model:
956
+            def to(self, device: str) -> _Model:
957
+                seen["device"] = device
958
+                return self
959
+
960
+            def eval(self) -> None:
961
+                seen["eval"] = True
962
+
963
+        model = _Model()
964
+
965
+        class AutoModelForCausalLM:
966
+            @staticmethod
967
+            def from_pretrained(hf_id: str) -> _Model:
968
+                seen["model_id"] = hf_id
969
+                return model
970
+
971
+        class AutoTokenizer:
972
+            @staticmethod
973
+            def from_pretrained(hf_id: str) -> str:
974
+                seen["tokenizer_id"] = hf_id
975
+                return "tok"
976
+
977
+        monkeypatch.setitem(
978
+            sys.modules,
979
+            "transformers",
980
+            _module(
981
+                "transformers",
982
+                AutoModelForCausalLM=AutoModelForCausalLM,
983
+                AutoTokenizer=AutoTokenizer,
984
+            ),
985
+        )
986
+
987
+        loaded = teachers_mod._default_hf_loader("hf/model", "cuda")
988
+
989
+        assert loaded.model is model
990
+        assert loaded.tokenizer == "tok"
991
+        assert loaded.device == "cuda"
992
+        assert seen == {
993
+            "model_id": "hf/model",
994
+            "tokenizer_id": "hf/model",
995
+            "device": "cuda",
996
+            "eval": True,
997
+        }
998
+
999
+    def test_default_hf_generate_seeds_torch_and_calls_runner(
1000
+        self,
1001
+        monkeypatch: pytest.MonkeyPatch,
1002
+    ) -> None:
1003
+        manual: list[int] = []
1004
+        manual_all: list[int] = []
1005
+        calls: dict[str, object] = {}
1006
+
1007
+        def _generate(
1008
+            model: object,
1009
+            tokenizer: object,
1010
+            prompt: str,
1011
+            *,
1012
+            max_new_tokens: int,
1013
+            temperature: float,
1014
+            top_p: float | None,
1015
+        ) -> str:
1016
+            calls["args"] = (model, tokenizer, prompt, max_new_tokens, temperature, top_p)
1017
+            return "ok"
1018
+
1019
+        monkeypatch.setitem(
1020
+            sys.modules,
1021
+            "dlm.inference.generate",
1022
+            _module("dlm.inference.generate", generate=_generate),
1023
+        )
1024
+        monkeypatch.setitem(
1025
+            sys.modules,
1026
+            "torch",
1027
+            SimpleNamespace(
1028
+                manual_seed=lambda seed: manual.append(seed),
1029
+                cuda=SimpleNamespace(
1030
+                    is_available=lambda: True,
1031
+                    manual_seed_all=lambda seed: manual_all.append(seed),
1032
+                ),
1033
+            ),
1034
+        )
1035
+
1036
+        out = teachers_mod._default_hf_generate(
1037
+            "model",
1038
+            "tokenizer",
1039
+            "prompt",
1040
+            max_new_tokens=17,
1041
+            temperature=0.3,
1042
+            top_p=0.8,
1043
+            seed=7,
1044
+        )
1045
+
1046
+        assert out == "ok"
1047
+        assert manual == [7]
1048
+        assert manual_all == [7]
1049
+        assert calls["args"] == ("model", "tokenizer", "prompt", 17, 0.3, 0.8)
1050
+
1051
+    def test_default_hf_generate_tolerates_missing_torch_when_seeding(
1052
+        self,
1053
+        monkeypatch: pytest.MonkeyPatch,
1054
+    ) -> None:
1055
+        real_import = builtins.__import__
1056
+
1057
+        def _generate(
1058
+            model: object,
1059
+            tokenizer: object,
1060
+            prompt: str,
1061
+            *,
1062
+            max_new_tokens: int,
1063
+            temperature: float,
1064
+            top_p: float | None,
1065
+        ) -> str:
1066
+            _ = model, tokenizer, prompt, max_new_tokens, temperature, top_p
1067
+            return "ok"
1068
+
1069
+        def _raise_torch(
1070
+            name: str,
1071
+            globals: dict[str, object] | None = None,
1072
+            locals: dict[str, object] | None = None,
1073
+            fromlist: tuple[str, ...] = (),
1074
+            level: int = 0,
1075
+        ) -> object:
1076
+            if name == "torch":
1077
+                raise ImportError("no torch")
1078
+            return real_import(name, globals, locals, fromlist, level)
1079
+
1080
+        monkeypatch.setitem(
1081
+            sys.modules,
1082
+            "dlm.inference.generate",
1083
+            _module("dlm.inference.generate", generate=_generate),
1084
+        )
1085
+        monkeypatch.delitem(sys.modules, "torch", raising=False)
1086
+        monkeypatch.setattr(builtins, "__import__", _raise_torch)
1087
+
1088
+        out = teachers_mod._default_hf_generate(
1089
+            "model",
1090
+            "tokenizer",
1091
+            "prompt",
1092
+            max_new_tokens=17,
1093
+            temperature=0.3,
1094
+            top_p=0.8,
1095
+            seed=7,
1096
+        )
1097
+
1098
+        assert out == "ok"
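The tests above lean on a `_module` helper (defined earlier in this file, outside this hunk) to stub entire import paths through `sys.modules`. A minimal sketch of the pattern, assuming one plausible helper body; the `fake_pkg_dlm_sketch` name is invented for illustration and nothing here is the dlm source:

```python
import sys
import types


def _module(name: str, **attrs: object) -> types.ModuleType:
    # Assumed shape of the helper: a bare module exposing only what a test needs.
    mod = types.ModuleType(name)
    for key, value in attrs.items():
        setattr(mod, key, value)
    return mod


# Registering the stub in sys.modules makes a plain `import` resolve to it,
# which is what monkeypatch.setitem(sys.modules, ...) does above (with
# automatic restoration when the test ends).
stub = _module("fake_pkg_dlm_sketch", answer=lambda: 42)
sys.modules["fake_pkg_dlm_sketch"] = stub

import fake_pkg_dlm_sketch  # served from the sys.modules cache, not from disk
```

Because the interpreter consults `sys.modules` before the import machinery, no file needs to exist for the stubbed names, so heavy dependencies like `torch` or `transformers` never load.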
tests/unit/templates/test_init.py (modified, 40 lines changed)
@@ -3,9 +3,12 @@
 from __future__ import annotations
 
 from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import patch
 
 import pytest
 
+from dlm.base_models import GatedModelError
 from dlm.doc.parser import parse_file
 from dlm.templates import TemplateApplyError, TemplateNotFoundError, apply_template
 
@@ -58,3 +61,28 @@ def test_apply_template_unknown_name_raises(tmp_path: Path) -> None:
         apply_template("nonexistent-template", target)
     # And doesn't leave a half-written file behind.
     assert not target.exists()
+
+
+def test_apply_template_wraps_gated_model_error(tmp_path: Path) -> None:
+    target = tmp_path / "out.dlm"
+
+    with (
+        patch(
+            "dlm.base_models.resolve",
+            side_effect=GatedModelError("llama-3.2-1b", "https://example.test/license"),
+        ),
+        pytest.raises(TemplateApplyError, match="pass accept_license=True"),
+    ):
+        apply_template("coding-tutor", target)
+
+
+def test_apply_template_refuses_gated_base_without_acceptance_flag(tmp_path: Path) -> None:
+    target = tmp_path / "out.dlm"
+    spec = SimpleNamespace(key="llama-3.2-1b")
+
+    with (
+        patch("dlm.base_models.resolve", return_value=spec),
+        patch("dlm.base_models.license.is_gated", return_value=True),
+        pytest.raises(TemplateApplyError, match="uses gated base"),
+    ):
+        apply_template("coding-tutor", target)
tests/unit/templates/test_registry.py (modified, 22 lines changed)
@@ -73,6 +73,22 @@ def test_registry_drops_template_with_malformed_meta(tmp_path: Path) -> None:
     assert list_bundled(gallery_dir=tmp_path) == []
 
 
+def test_load_template_rejects_non_mapping_meta(tmp_path: Path) -> None:
+    (tmp_path / "broken.dlm").write_text("---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n")
+    (tmp_path / "broken.meta.yaml").write_text("- not\n- a\n- mapping\n")
+
+    with pytest.raises(TemplateMetaError, match="meta must be a YAML mapping"):
+        load_template("broken", gallery_dir=tmp_path)
+
+
+def test_load_template_rejects_schema_invalid_meta(tmp_path: Path) -> None:
+    (tmp_path / "broken.dlm").write_text("---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n")
+    (tmp_path / "broken.meta.yaml").write_text("name: broken\ntitle: Broken\nsummary: hi\n")
+
+    with pytest.raises(TemplateMetaError, match="failed schema validation"):
+        load_template("broken", gallery_dir=tmp_path)
+
+
 def test_load_template_with_mismatched_name_raises(tmp_path: Path) -> None:
     (tmp_path / "fine.dlm").write_text("---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n")
     # meta.name doesn't match the filename stem.
tests/unit/test_io_atomic.py (modified, 51 lines changed)
@@ -82,6 +82,10 @@ class TestNonceSuffix:
 
 
 class TestCleanupStaleTmp:
+    def test_cleanup_skips_directories(self, tmp_path: Path) -> None:
+        (tmp_path / "nested.tmp.999.deadbeef").mkdir()
+        assert atomic.cleanup_stale_tmp_files(tmp_path) == []
+
     def test_removes_only_dead_pid_tmp_files(self, tmp_path: Path) -> None:
         """Legacy nonce-less tmps still get cleaned up — back-compat for
         sweeps that span a pre-/post-upgrade writer on the same store."""
@@ -115,6 +119,41 @@ class TestCleanupStaleTmp:
         assert atomic.cleanup_stale_tmp_files(tmp_path) == []
         assert malformed.exists()
 
+    def test_cleanup_ignores_tmp_file_removed_between_list_and_unlink(self, tmp_path: Path) -> None:
+        doomed = tmp_path / "name.tmp.99999999.deadbeef"
+        doomed.write_bytes(b"x")
+
+        real_unlink = Path.unlink
+
+        def fake_unlink(self: Path, *args: object, **kwargs: object) -> None:
+            if self == doomed:
+                raise FileNotFoundError
+            real_unlink(self, *args, **kwargs)
+
+        with (
+            patch("dlm.io.atomic._is_alive", return_value=False),
+            patch("pathlib.Path.unlink", autospec=True, side_effect=fake_unlink),
+        ):
+            assert atomic.cleanup_stale_tmp_files(tmp_path) == []
+
+
+class TestTmpPid:
+    def test_invalid_regex_pid_falls_back_to_none(self, tmp_path: Path) -> None:
+        target = tmp_path / "file.bin.tmp.1234.deadbeef"
+
+        class FakeMatch:
+            @staticmethod
+            def group(name: str) -> str:
+                assert name == "pid"
+                return "not-a-pid"
+
+        fake_pattern = type(
+            "FakePattern", (), {"search": staticmethod(lambda _name: FakeMatch())}
+        )()
+
+        with patch("dlm.io.atomic._TMP_RE", fake_pattern):
+            assert atomic._tmp_pid(target) is None
+
 
 class TestIsAlive:
     def test_zero_or_negative_dead(self) -> None:
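The race test above patches `Path.unlink` with `autospec=True` so the bound instance arrives as `self`, letting the fake raise only for one doomed path. A standalone sketch of that mechanism under illustrative names (no dlm module is imported here):

```python
import tempfile
from pathlib import Path
from unittest.mock import patch

tmp = Path(tempfile.mkdtemp())
doomed = tmp / "name.tmp.99999999.deadbeef"
doomed.write_bytes(b"x")

real_unlink = Path.unlink


def fake_unlink(self: Path, *args: object, **kwargs: object) -> None:
    # Simulate another process deleting this one file between listing and unlink.
    if self == doomed:
        raise FileNotFoundError(self)
    real_unlink(self, *args, **kwargs)


raced = False
with patch("pathlib.Path.unlink", autospec=True, side_effect=fake_unlink):
    try:
        doomed.unlink()
    except FileNotFoundError:
        raced = True
```

`autospec=True` is what forwards `self` as the first positional argument; without it the side effect could never tell which path was being unlinked.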
tests/unit/test_main.py (added, 15 lines changed)
@@ -0,0 +1,15 @@
+"""Direct coverage for the `python -m dlm` entrypoint."""
+
+from __future__ import annotations
+
+import runpy
+
+
+def test_module_entrypoint_invokes_cli_main(monkeypatch) -> None:
+    called: list[bool] = []
+
+    monkeypatch.setattr("dlm.cli.app.main", lambda: called.append(True))
+
+    runpy.run_module("dlm", run_name="__main__")
+
+    assert called == [True]
tests/unit/test_package_init.py (added, 24 lines changed)
@@ -0,0 +1,24 @@
+"""Direct coverage for package-level version fallback wiring."""
+
+from __future__ import annotations
+
+import runpy
+from importlib.metadata import PackageNotFoundError
+from pathlib import Path
+from unittest.mock import patch
+
+_INIT_PATH = Path(__file__).resolve().parents[2] / "src" / "dlm" / "__init__.py"
+
+
+def test_package_init_reads_installed_version() -> None:
+    with patch("importlib.metadata.version", return_value="1.2.3"):
+        module_globals = runpy.run_path(str(_INIT_PATH))
+
+    assert module_globals["__version__"] == "1.2.3"
+
+
+def test_package_init_falls_back_when_package_metadata_is_missing() -> None:
+    with patch("importlib.metadata.version", side_effect=PackageNotFoundError):
+        module_globals = runpy.run_path(str(_INIT_PATH))
+
+    assert module_globals["__version__"] == "0.0.0+unknown"
tests/unit/train/cpt/test_embed_warmup.py (modified, 22 lines changed)
@@ -35,6 +35,22 @@ def _model(*, embed_frozen: bool = True, head_frozen: bool = True, tied: bool =
 
 
 class TestUnfreezeContextManager:
+    def test_missing_embedding_modules_yield_empty_list(self) -> None:
+        model = SimpleNamespace(
+            get_input_embeddings=lambda: None,
+            get_output_embeddings=lambda: None,
+        )
+        with unfreeze_embeddings_for(model) as weights:
+            assert weights == []
+
+    def test_modules_without_weight_are_skipped(self) -> None:
+        model = SimpleNamespace(
+            get_input_embeddings=lambda: SimpleNamespace(weight=None),
+            get_output_embeddings=lambda: SimpleNamespace(weight=None),
+        )
+        with unfreeze_embeddings_for(model) as weights:
+            assert weights == []
+
     def test_unfreezes_both_embeddings(self) -> None:
         model = _model(embed_frozen=True, head_frozen=True)
         with unfreeze_embeddings_for(model) as weights:
tests/unit/train/distributed/test_gpus.py (modified, 30 lines changed)
@@ -8,6 +8,10 @@ from dlm.train.distributed.gpus import GpuSpec, UnsupportedGpuSpecError, parse_g
 
 
 class TestParseGpus:
+    def test_none_raises_empty(self) -> None:
+        with pytest.raises(UnsupportedGpuSpecError, match="empty"):
+            parse_gpus(None)  # type: ignore[arg-type]
+
     def test_all_case_insensitive(self) -> None:
         for value in ("all", "ALL", "All"):
             spec = parse_gpus(value)
@@ -35,12 +39,20 @@ class TestParseGpus:
         with pytest.raises(UnsupportedGpuSpecError, match="non-integer"):
             parse_gpus("0,foo,1")
 
+    def test_empty_comma_list_rejected(self) -> None:
+        with pytest.raises(UnsupportedGpuSpecError, match="is empty"):
+            parse_gpus(", ,")
+
     def test_malformed_scalar_rejected(self) -> None:
         with pytest.raises(UnsupportedGpuSpecError, match="not `all`"):
             parse_gpus("xyz")
 
 
 class TestResolveGpuSpec:
+    def test_list_returns_requested_ids(self) -> None:
+        spec = GpuSpec(kind="list", value=(0, 2))
+        assert spec.resolve(device_count=4) == (0, 2)
+
     def test_all_returns_full_range(self) -> None:
         spec = GpuSpec(kind="all", value=None)
         assert spec.resolve(device_count=3) == (0, 1, 2)
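Taken together, these cases pin the `parse_gpus` contract: case-insensitive `all`, comma lists of ids, and distinct errors for empty input, blank lists, non-integer ids, and malformed scalars. A hypothetical re-implementation that satisfies the asserted messages; this is a sketch, not the `dlm.train.distributed.gpus` source:

```python
from __future__ import annotations

from dataclasses import dataclass


class UnsupportedGpuSpecError(ValueError):
    pass


@dataclass(frozen=True)
class GpuSpec:
    kind: str  # "all" or "list"
    value: tuple[int, ...] | None

    def resolve(self, *, device_count: int) -> tuple[int, ...]:
        if self.kind == "all":
            return tuple(range(device_count))
        assert self.value is not None
        return self.value


def parse_gpus(raw: str | None) -> GpuSpec:
    if raw is None or not raw.strip():
        raise UnsupportedGpuSpecError("gpu spec is empty")
    text = raw.strip()
    if text.lower() == "all":
        return GpuSpec(kind="all", value=None)
    if "," in text:
        parts = [p.strip() for p in text.split(",") if p.strip()]
        if not parts:
            raise UnsupportedGpuSpecError(f"gpu list {text!r} is empty")
        try:
            return GpuSpec(kind="list", value=tuple(int(p) for p in parts))
        except ValueError as exc:
            raise UnsupportedGpuSpecError(f"non-integer gpu id in {text!r}") from exc
    try:
        return GpuSpec(kind="list", value=(int(text),))
    except ValueError as exc:
        raise UnsupportedGpuSpecError(f"{text!r} is not `all` or a gpu id list") from exc
```

Splitting parse (`parse_gpus`) from resolution (`GpuSpec.resolve`) mirrors the test structure above: validation errors surface at parse time, while `device_count` only matters when `all` is expanded.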
tests/unit/train/distributed/test_rank_env.py (modified, 10 lines changed)
@@ -50,6 +50,10 @@ class TestDetectRank:
         monkeypatch.setenv("LOCAL_RANK", "2")
         assert detect_rank() == 2
 
+    def test_negative_rank_clamped_to_zero(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("RANK", "-3")
+        assert detect_rank() == 0
+
     def test_malformed_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv("RANK", "oops")
         with pytest.raises(ValueError, match="RANK"):
tests/unit/train/distributed/test_rank_io.py (modified, 8 lines changed)
@@ -108,3 +108,8 @@ class TestGatherMetrics:
         out = gather_metrics(acc, {"loss": 2.0})
         # mean of [2.0, 3.0] = 2.5
         assert out["loss"] == pytest.approx(2.5)
+
+    def test_gather_none_falls_back_to_original_value(self) -> None:
+        acc = SimpleNamespace(gather_for_metrics=lambda tensor: None, is_main_process=True)
+        out = gather_metrics(acc, {"loss": 2.0})
+        assert out == {"loss": 2.0}
tests/unit/train/gate/test_module.py (modified, 26 lines changed)
@@ -81,3 +81,26 @@ class TestGateMetadataJson:
                     "mode": "trained",
                 }
             )
+
+    def test_non_integer_dims_rejected(self) -> None:
+        with pytest.raises(GateConfigError, match="input_dim/hidden_proj_dim"):
+            GateMetadata.from_json(
+                {
+                    "input_dim": "8",
+                    "hidden_proj_dim": 4,
+                    "adapter_names": ["a", "b"],
+                    "mode": "trained",
+                }
+            )
+
+    def test_non_numeric_entropy_rejected(self) -> None:
+        with pytest.raises(GateConfigError, match="entropy_lambda"):
+            GateMetadata.from_json(
+                {
+                    "input_dim": 8,
+                    "hidden_proj_dim": 4,
+                    "adapter_names": ["a", "b"],
+                    "mode": "trained",
+                    "entropy_lambda": "high",
+                }
+            )
tests/unit/train/gate/test_orchestrator.py (modified, 22 lines changed)
@@ -171,6 +171,22 @@ class TestRunPostSftGate:
         )
         assert result is None
 
+    def test_exactly_one_named_adapter_returns_none(self, tmp_path: Path) -> None:
+        parsed = _parsed((_prose("x", adapter="solo"),), gate_enabled=False, adapters=("solo",))
+        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
+        store = StorePath(root=tmp_path)
+        store.ensure_layout()
+        recorder = MetricsRecorder(tmp_path)
+        result = run_post_sft_gate(
+            store,
+            parsed,
+            run_id=1,
+            recorder=recorder,
+            embed=lambda _p: _tensor(4),
+            input_dim=4,
+        )
+        assert result is None
+
     def test_cold_start_fallback_records_uniform_events(self, tmp_path: Path) -> None:
         parsed = _parsed((_prose("only-a", adapter="a"),))
         store = StorePath(root=tmp_path)
tests/unit/train/gate/test_trainer.py (modified, 30 lines changed)
@@ -14,6 +14,7 @@ from dlm.train.gate import (
     load_gate,
     train_gate,
 )
+from dlm.train.gate.errors import GateTrainingError
 from dlm.train.gate.paths import gate_config_path, gate_save_path
 
 
@@ -162,6 +163,23 @@ class TestConvergence:
         )
         assert len(seen) == 7
 
+    def test_non_finite_loss_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        import torch
+
+        store = _store(tmp_path)
+        samples = _synthetic_samples(per_class=5, input_dim=4, seed=0)
+        monkeypatch.setattr(torch, "isfinite", lambda value: torch.tensor(False))
+        with pytest.raises(GateTrainingError, match="non-finite"):
+            train_gate(
+                store,  # type: ignore[arg-type]
+                samples,
+                adapter_names=["a", "b"],
+                input_dim=4,
+                steps=1,
+                cold_start_floor=1,
+                batch_size=4,
+            )
+
 
 class TestLoadGateErrors:
     def test_missing_config(self, tmp_path: Path) -> None:
tests/unit/train/multi_adapter/test_orchestrator.py (modified, 72 lines changed)
@@ -125,6 +125,72 @@ class TestSingleAdapterPassthrough:
         # Flat layout: version dir lives under adapter/versions/, not a named subdir.
         assert store.adapter_version(1).is_dir()
 
+    def test_one_named_adapter_still_passthroughs(self, tmp_path: Path) -> None:
+        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FZ"
+        store = _seed_store(tmp_path, dlm_id)
+        parsed = ParsedDlm(
+            frontmatter=DlmFrontmatter(
+                dlm_id=dlm_id,
+                base_model="smollm2-135m",
+                training=TrainingConfig(
+                    seed=42,
+                    adapters={"knowledge": AdapterConfig()},
+                ),
+            ),
+            sections=(
+                Section(type=SectionType.PROSE, content="Shared domain prose."),
+                Section(
+                    type=SectionType.INSTRUCTION,
+                    content="### Q\nfacts?\n### A\nfacts.",
+                    adapter="knowledge",
+                ),
+            ),
+        )
+        results = run_all(
+            store,
+            parsed,
+            BASE_MODELS["smollm2-135m"],
+            _plan(),
+            mode="fresh",
+            trainer_factory=_mock_trainer_factory,
+        )
+        assert len(results) == 1
+
+    def test_gate_enabled_with_one_named_adapter_still_returns_one_result(
+        self, tmp_path: Path
+    ) -> None:
+        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FY"
+        store = _seed_store(tmp_path, dlm_id)
+        parsed = ParsedDlm(
+            frontmatter=DlmFrontmatter(
+                dlm_id=dlm_id,
+                base_model="smollm2-135m",
+                training=TrainingConfig(
+                    seed=42,
+                    adapters={"knowledge": AdapterConfig()},
+                    gate=GateConfig(enabled=False),
+                ),
+            ),
+            sections=(
+                Section(type=SectionType.PROSE, content="Shared domain prose."),
+                Section(
+                    type=SectionType.INSTRUCTION,
+                    content="### Q\nfacts?\n### A\nfacts.",
+                    adapter="knowledge",
+                ),
+            ),
+        )
+        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
+        results = run_all(
+            store,
+            parsed,
+            BASE_MODELS["smollm2-135m"],
+            _plan(),
+            mode="fresh",
+            trainer_factory=_mock_trainer_factory,
+        )
+        assert len(results) == 1
+
 
 class TestMultiAdapterOrchestration:
     def test_trains_each_declared_adapter(self, tmp_path: Path) -> None:
tests/unit/train/preference/test_dpo_phase.py (modified, 44 lines changed)
@@ -7,11 +7,15 @@ manifest → state-sidecar without importing HF/TRL or torch.
 
 from __future__ import annotations
 
+from dataclasses import replace
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any
 from unittest.mock import MagicMock
 
+import pytest
+
+import dlm.train.preference.dpo_phase as dpo_phase
 from dlm.base_models import BASE_MODELS
 from dlm.doc.parser import ParsedDlm
 from dlm.doc.schema import DlmFrontmatter, PreferenceConfig, TrainingConfig
@@ -220,3 +224,29 @@ class TestRunSteps:
             trainer_factory=_capturing_factory,
         )
         assert captured["include_auto_mined"] is False
+
+    def test_writes_lock_when_decision_requests_it(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        store = for_dlm("01DPOTEST5", home=tmp_path)
+        _seed_prior_sft(store, dlm_id="01DPOTEST5")
+        parsed = replace(_parsed_with_preferences(), source_path=tmp_path / "doc.dlm")
+        persist_lock = MagicMock()
+
+        monkeypatch.setattr(
+            dpo_phase,
+            "_validate_or_abort_lock",
+            lambda **_kwargs: SimpleNamespace(should_write_lock=True),
+        )
+        monkeypatch.setattr(dpo_phase, "_persist_lock", persist_lock)
+
+        run(
+            store,
+            parsed,
+            BASE_MODELS["smollm2-135m"],
+            _plan(),
+            reference_adapter_version=1,
+            trainer_factory=_mock_factory,
+        )
+
+        persist_lock.assert_called_once()
tests/unit/train/preference/test_orpo_phase.py (modified, 47 lines changed)
@@ -8,11 +8,15 @@ test; audit-07 B3 closes the 0% coverage gap.
 
 from __future__ import annotations
 
+from dataclasses import replace
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any
 from unittest.mock import MagicMock
 
+import pytest
+
+import dlm.train.preference.orpo_phase as orpo_phase
 from dlm.base_models import BASE_MODELS
 from dlm.doc.parser import ParsedDlm
 from dlm.doc.schema import DlmFrontmatter, PreferenceConfig, TrainingConfig
@@ -260,6 +264,32 @@ class TestRunSteps:
         )
         assert captured["max_steps"] == 25
 
+    def test_writes_lock_when_decision_requests_it(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        store = for_dlm("01ORPOTEST5", home=tmp_path)
+        _seed_prior_sft(store, dlm_id="01ORPOTEST5")
+        parsed = replace(_parsed_with_preferences(), source_path=tmp_path / "doc.dlm")
+        persist_lock = MagicMock()
+
+        monkeypatch.setattr(
+            orpo_phase,
+            "_validate_or_abort_lock",
+            lambda **_kwargs: SimpleNamespace(should_write_lock=True),
+        )
+        monkeypatch.setattr(orpo_phase, "_persist_lock", persist_lock)
+
+        run(
+            store,
+            parsed,
+            BASE_MODELS["smollm2-135m"],
+            _plan(),
+            reference_adapter_version=1,
+            trainer_factory=_mock_factory,
+        )
+
+        persist_lock.assert_called_once()
+
 
 class TestLockModes:
     def test_ignore_mode_skips_lock_write(self, tmp_path: Path) -> None:
tests/unit/train/preference/test_phase_orchestrator.py (modified, 46 lines changed)
@@ -14,6 +14,7 @@ from unittest.mock import MagicMock
 
 import pytest
 
+import dlm.train.preference.phase_orchestrator as phase_orchestrator
 from dlm.doc.schema import PreferenceConfig
 from dlm.doc.sections import Section, SectionType
 from dlm.train.preference.errors import (
@@ -176,6 +177,23 @@ class TestDispatcherSftOnly:
         _, kwargs = sft.call_args
         assert kwargs["strict_metrics"] is True
 
+    def test_sft_phase_forwards_world_size_to_sft_runner(self) -> None:
+        sft = MagicMock(return_value=_FakeRunResult(adapter_version=1))
+
+        run_phases(
+            store=MagicMock(),
+            parsed=_parsed([_prose()]),
+            spec=MagicMock(),
+            plan=MagicMock(),
+            phase="sft",
+            world_size=4,
+            sft_runner=sft,
+            dpo_runner=MagicMock(),
+        )
+
+        _, kwargs = sft.call_args
+        assert kwargs["world_size"] == 4
+
     def test_sft_phase_skips_when_no_sft_content(self) -> None:
         sft = MagicMock()
         dpo = MagicMock()
@@ -378,6 +396,16 @@ class TestPhaseResult:
             pr.phase = "dpo"  # type: ignore[misc]
 
 
+class TestMethodRunner:
+    def test_method_runner_uses_registry_resolver(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        fake = MagicMock()
+        monkeypatch.setattr(
+            "dlm.train.preference.method_registry.resolve",
+            lambda method: fake if method == "orpo" else None,
+        )
+        assert phase_orchestrator._method_runner("orpo") is fake
+
+
 class TestAutoEnableIntegration:
     """Auto-enable: when user didn't set `enabled` and preference
     content is present, DPO runs under `--phase all`."""
tests/unit/train/test_cache.py (added, 38 lines changed)
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+import logging
+import os
+
+from dlm.train.cache import DISABLE_ENV_VAR, disabled_cache, is_cache_disabled, set_disable_flag
+
+
+class TestCacheDisableFlag:
+    def test_disabled_false_by_default(self, monkeypatch) -> None:
+        monkeypatch.delenv(DISABLE_ENV_VAR, raising=False)
+        assert is_cache_disabled() is False
+
+    def test_set_disable_flag_sets_env_and_logs(self, monkeypatch, caplog) -> None:
+        monkeypatch.delenv(DISABLE_ENV_VAR, raising=False)
+        with caplog.at_level(logging.INFO):
+            set_disable_flag("cli flag")
+        assert is_cache_disabled() is True
+        assert "tokenized cache disabled (cli flag)" in caplog.text
+
+    def test_disabled_cache_restores_missing_prior_value(self, monkeypatch) -> None:
+        monkeypatch.delenv(DISABLE_ENV_VAR, raising=False)
+        with disabled_cache("scoped test"):
+            assert is_cache_disabled() is True
+        assert DISABLE_ENV_VAR not in os.environ
+        assert is_cache_disabled() is False
+
+    def test_disabled_cache_restores_prior_value(self, monkeypatch) -> None:
+        monkeypatch.setenv(DISABLE_ENV_VAR, "0")
+        with disabled_cache("scoped test"):
+            assert is_cache_disabled() is True
+        assert is_cache_disabled() is False
+
+    def test_disabled_cache_preserves_existing_disabled_state(self, monkeypatch) -> None:
+        monkeypatch.setenv(DISABLE_ENV_VAR, "1")
+        with disabled_cache("already disabled"):
+            assert is_cache_disabled() is True
+        assert is_cache_disabled() is True
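The new test file above pins down the env-flag contract of `dlm.train.cache` without showing the module itself. A minimal sketch consistent with these tests could look as follows — the environment-variable name and log wording are assumptions taken from what the tests assert, not the real module source:

```python
import contextlib
import logging
import os
from collections.abc import Iterator

logger = logging.getLogger(__name__)

# Assumed name; the real module defines its own DISABLE_ENV_VAR constant.
DISABLE_ENV_VAR = "DLM_DISABLE_TOKENIZED_CACHE"


def is_cache_disabled() -> bool:
    # Only the literal "1" counts as disabled, so a prior "0" reads as enabled.
    return os.environ.get(DISABLE_ENV_VAR) == "1"


def set_disable_flag(reason: str) -> None:
    # Set the flag process-wide and record why, matching the caplog assertion.
    os.environ[DISABLE_ENV_VAR] = "1"
    logger.info("tokenized cache disabled (%s)", reason)


@contextlib.contextmanager
def disabled_cache(reason: str) -> Iterator[None]:
    # Scoped disable: remember the prior value, set the flag, restore on exit.
    prior = os.environ.get(DISABLE_ENV_VAR)
    set_disable_flag(reason)
    try:
        yield
    finally:
        if prior is None:
            os.environ.pop(DISABLE_ENV_VAR, None)
        else:
            os.environ[DISABLE_ENV_VAR] = prior
```

The restore-on-exit branch is what the `restores_missing_prior_value` and `preserves_existing_disabled_state` tests exercise: a missing variable must stay missing afterwards, while a pre-existing `"1"` must survive the context manager.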
tests/unit/train/test_checkpoint_commit.py (modified, 118 lines changed)
@@ -6,6 +6,7 @@ from pathlib import Path
 
 import pytest
 
+import dlm.train.checkpoint_commit as checkpoint_commit
 from dlm.store.paths import for_dlm
 from dlm.train.checkpoint_commit import (
     _uniquify_rejected,
@@ -14,6 +15,7 @@ from dlm.train.checkpoint_commit import (
     fsync_dir,
     list_pending_versions,
 )
+from dlm.train.integrity import NaNWeightsError
 
 
 def _store(home: Path):
@@ -47,6 +49,12 @@ class TestAllocation:
         v1 = allocate_next_version(store)
         assert v1.name == "v0001"
 
+    def test_ignores_non_dir_entries(self, tmp_path: Path) -> None:
+        store = _store(tmp_path)
+        (store.adapter_versions / "v0009").write_text("not a directory")
+        v1 = allocate_next_version(store)
+        assert v1.name == "v0001"
+
 
 class TestCommitVersion:
     def test_happy_path_flips_current(self, tmp_path: Path) -> None:
@@ -97,6 +105,60 @@ class TestCommitVersion:
         v2 = allocate_next_version(store)
         assert v2.name == "v0002"
 
+    def test_nonfinite_writer_uniquify_failure_leaves_pending(
+        self, tmp_path: Path, monkeypatch
+    ) -> None:
+        store = _store(tmp_path)
+
+        def bad_writer(p: Path) -> None:
+            (p / "weights.safetensors").write_text("bad")
+            raise NaNWeightsError(["adapter.lora_A"])
+
+        def boom(_: Path) -> Path:
+            raise RuntimeError("no rejected slot")
+
+        monkeypatch.setattr(checkpoint_commit, "_uniquify_rejected", boom)
+
+        with pytest.raises(RuntimeError, match="no rejected slot"):
+            commit_version(store, bad_writer)
+
+        assert store.adapter_version(1).exists()
+        assert store.resolve_current_adapter() is None
+
+    def test_nonfinite_writer_rename_failure_still_reraises(
+        self, tmp_path: Path, monkeypatch
+    ) -> None:
+        store = _store(tmp_path)
+
+        def bad_writer(p: Path) -> None:
+            (p / "weights.safetensors").write_text("bad")
+            raise NaNWeightsError(["adapter.lora_B"])
+
+        def bad_rename(self: Path, target: Path) -> Path:
+            raise OSError("rename blocked")
+
+        monkeypatch.setattr(Path, "rename", bad_rename)
+
+        with pytest.raises(NaNWeightsError, match="NaN/inf"):
+            commit_version(store, bad_writer)
+
+        assert store.adapter_version(1).exists()
+        assert store.resolve_current_adapter() is None
+
+    def test_nonfinite_writer_renames_to_rejected_path(self, tmp_path: Path) -> None:
+        store = _store(tmp_path)
+
+        def bad_writer(p: Path) -> None:
+            (p / "weights.safetensors").write_text("bad")
+            raise NaNWeightsError(["adapter.lora_B"])
+
+        with pytest.raises(NaNWeightsError, match="NaN/inf"):
+            commit_version(store, bad_writer)
+
+        assert not store.adapter_version(1).exists()
+        assert (store.adapter_versions / "v0001-rejected").exists()
+        assert store.resolve_current_adapter() is None
+
 
 class TestListPending:
     def test_no_pending_when_all_committed(self, tmp_path: Path) -> None:
@@ -115,6 +177,19 @@ class TestListPending:
         pending = list_pending_versions(store)
         assert [p.name for p in pending] == [v1.name]
 
+    def test_named_adapter_pending_versions_report_orphans(self, tmp_path: Path) -> None:
+        store = _store(tmp_path)
+        orphan = allocate_next_version(store, adapter_name="writer")
+        commit_version(store, lambda p: (p / "a").write_text("a"), adapter_name="writer")
+        pending = list_pending_versions(store, adapter_name="writer")
+        assert [p.name for p in pending] == [orphan.name]
+
+    def test_named_adapter_pending_versions_without_current(self, tmp_path: Path) -> None:
+        store = _store(tmp_path)
+        orphan = allocate_next_version(store, adapter_name="writer")
+        pending = list_pending_versions(store, adapter_name="writer")
+        assert pending == [orphan]
+
 
 class TestFsyncDir:
     def test_fsync_no_error_on_real_dir(self, tmp_path: Path) -> None:
@@ -123,6 +198,13 @@
 
 
 class TestRejectedPathAllocation:
+    def test_returns_first_available_suffix(self, tmp_path: Path) -> None:
+        pending = tmp_path / "v0001"
+        pending.mkdir()
+        (tmp_path / "v0001-rejected").mkdir()
+        (tmp_path / "v0001-rejected-1").mkdir()
+        assert _uniquify_rejected(pending) == tmp_path / "v0001-rejected-2"
+
    def test_raises_after_1000_collisions(self, tmp_path: Path) -> None:
        pending = tmp_path / "v0001"
        pending.mkdir()
tests/unit/train/test_inject.py (modified, 10 lines changed)
@@ -37,6 +37,10 @@ class TestQueue:
         with pytest.raises(ValueError, match="capacity must be positive"):
             InjectedProbeQueue(capacity=0)
 
+    def test_capacity_property_reflects_configured_limit(self) -> None:
+        q = InjectedProbeQueue(capacity=8)
+        assert q.capacity == 8
+
     def test_depth_reports_current(self) -> None:
         q = InjectedProbeQueue(capacity=8)
         assert q.depth() == 0
tests/unit/train/test_integrity.py (modified, 9 lines changed)
@@ -87,6 +87,9 @@ class TestAssertEvalFinite:
         # "check iff eval ran", so no eval entries means nothing to check.
         assert_eval_finite([{"loss": 2.0, "step": 1}, {"loss": 1.5, "step": 2}])
 
+    def test_non_dict_entries_ignored(self) -> None:
+        assert_eval_finite([{"loss": 2.0, "step": 1}, "not-a-dict"])
+
     def test_finite_eval_does_not_raise(self) -> None:
         assert_eval_finite([{"eval_loss": 1.8, "step": 10}])
 
tests/unit/train/test_logger.py (modified, 36 lines changed)
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import json
+from dataclasses import dataclass
 from pathlib import Path
 
 import pytest
@@ -11,6 +12,11 @@ from dlm.train.logger import Banner, StepLogger, log_path_for
 
 
 class TestContextManager:
+    def test_path_property_round_trips(self, tmp_path: Path) -> None:
+        path = tmp_path / "x.jsonl"
+        log = StepLogger(path)
+        assert log.path == path
+
     def test_outside_context_raises(self, tmp_path: Path) -> None:
         log = StepLogger(tmp_path / "x.jsonl")
         with pytest.raises(RuntimeError, match="not open"):
@@ -94,6 +100,18 @@ class TestEventLogging:
         parsed = json.loads(p.read_text().strip())
         assert parsed["val_ppl"] == 4.5
 
+    def test_dataclass_fields_are_sanitized(self, tmp_path: Path) -> None:
+        @dataclass
+        class _Payload:
+            step: int
+            note: str
+
+        p = tmp_path / "run.jsonl"
+        with StepLogger(p) as log:
+            log.log_event("custom", payload=_Payload(step=3, note="ok"))
+        parsed = json.loads(p.read_text().strip())
+        assert parsed["payload"] == {"step": 3, "note": "ok"}
+
 
 class TestLogPath:
     def test_shape(self, tmp_path: Path) -> None:
tests/unit/train/test_rpc.py (modified, 134 lines changed)
@@ -10,6 +10,7 @@ assign (`port=0`).
 from __future__ import annotations
 
 import json
+import socket
 import urllib.error
 import urllib.request
 from collections.abc import Iterator
@@ -60,6 +61,42 @@ def _post(
     return resp.status, json.loads(resp.read())
 
 
+def _raw_post(
+    server: ProbeRpcServer,
+    *,
+    headers: dict[str, str],
+    body: bytes = b"",
+    path: str = "/rpc",
+) -> tuple[int, dict[str, Any]]:
+    host, port = server.address
+    lines = [
+        f"POST {path} HTTP/1.1",
+        f"Host: {host}:{port}",
+        *[f"{key}: {value}" for key, value in headers.items()],
+        "",
+        "",
+    ]
+    request = "\r\n".join(lines).encode("utf-8") + body
+    with socket.create_connection((host, port), timeout=5.0) as sock:
+        sock.sendall(request)
+        response = b""
+        while b"\r\n\r\n" not in response:
+            response += sock.recv(4096)
+        head, rest = response.split(b"\r\n\r\n", 1)
+        header_lines = head.decode("iso-8859-1").split("\r\n")
+        status = int(header_lines[0].split()[1])
+        parsed_headers: dict[str, str] = {}
+        for line in header_lines[1:]:
+            if ":" not in line:
+                continue
+            key, value = line.split(":", 1)
+            parsed_headers[key.lower()] = value.strip()
+        content_length = int(parsed_headers.get("content-length", "0"))
+        while len(rest) < content_length:
+            rest += sock.recv(4096)
+    return status, json.loads(rest[:content_length].decode("utf-8"))
+
+
 class TestHappyPath:
     def test_inject_probe_accepted(self, server: ProbeRpcServer) -> None:
         status, body = _post(
@@ -102,11 +139,65 @@ class TestMalformedPayload:
         assert status == 400
         assert "malformed" in body["error"].lower()
 
+    def test_invalid_content_length_400(self, server: ProbeRpcServer) -> None:
+        status, body = _raw_post(
+            server,
+            headers={
+                "Authorization": f"Bearer {_TOKEN}",
+                "Content-Type": "application/json",
+                "Content-Length": "nope",
+            },
+        )
+        assert status == 400
+        assert "content-length" in body["error"].lower()
+
+    def test_empty_body_400(self, server: ProbeRpcServer) -> None:
+        status, body = _raw_post(
+            server,
+            headers={
+                "Authorization": f"Bearer {_TOKEN}",
+                "Content-Type": "application/json",
+                "Content-Length": "0",
+            },
+        )
+        assert status == 400
+        assert "empty body" in body["error"].lower()
+
+    def test_oversized_body_400(self, server: ProbeRpcServer) -> None:
+        status, body = _raw_post(
+            server,
+            headers={
+                "Authorization": f"Bearer {_TOKEN}",
+                "Content-Type": "application/json",
+                "Content-Length": str(70 * 1024),
+            },
+        )
+        assert status == 400
+        assert "exceeds" in body["error"].lower()
+
+    def test_payload_must_be_object(self, server: ProbeRpcServer) -> None:
+        status, body = _post(server, body="[]")
+        assert status == 400
+        assert "json object" in body["error"].lower()
+
     def test_missing_prompt_400(self, server: ProbeRpcServer) -> None:
         status, body = _post(server, body={"method": "inject_probe", "params": {"reference": "a"}})
         assert status == 400
         assert "prompt" in body["error"].lower()
 
+    def test_params_must_be_object(self, server: ProbeRpcServer) -> None:
+        status, body = _post(server, body={"method": "inject_probe", "params": "bad"})
+        assert status == 400
+        assert "`params`" in body["error"]
+
+    def test_empty_reference_400(self, server: ProbeRpcServer) -> None:
+        status, body = _post(
+            server,
+            body={"method": "inject_probe", "params": {"prompt": "q", "reference": "   "}},
+        )
+        assert status == 400
+        assert "reference" in body["error"].lower()
+
     def test_non_string_tags_400(self, server: ProbeRpcServer) -> None:
         status, body = _post(
             server,
@@ -165,3 +256,20 @@ class TestConstruction:
     def test_empty_token_rejected(self) -> None:
         with pytest.raises(ValueError, match="bearer token"):
             ProbeRpcServer(host="127.0.0.1", port=0, token="", queue=InjectedProbeQueue())
+
+    def test_start_twice_rejected(self) -> None:
+        try:
+            srv = ProbeRpcServer(
+                host="127.0.0.1",
+                port=0,
+                token=_TOKEN,
+                queue=InjectedProbeQueue(),
+            )
+        except PermissionError as exc:
+            pytest.skip(f"loopback bind blocked on this host: {exc}")
+        srv.start()
+        try:
+            with pytest.raises(RuntimeError, match="already started"):
+                srv.start()
+        finally:
+            srv.stop()
tests/unit/train/test_state_sidecar.py (modified, 94 lines changed)
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import builtins
 import hashlib
 import io
 import json
@@ -23,6 +24,8 @@ from dlm.train.state_sidecar import (
     TRAINING_RUN_FILENAME,
     VERSIONS_FILENAME,
     TrainingState,
+    _decode_python_random_state,
+    _encode_python_random_state,
     capture_runtime_versions,
     load_state,
     save_state,
@@ -280,6 +283,10 @@ class TestRngSidecar:
         # live in the JSON sidecar.
         assert "numpy_rng_state" not in payload
 
+    def test_python_random_none_helpers_round_trip(self) -> None:
+        assert _encode_python_random_state(None) is None
+        assert _decode_python_random_state(None) is None
+
 
 class TestLegacyV1Compat:
     """Audit-11 B7: one-release back-compat for pre-B7 sidecars.
@@ -334,6 +341,54 @@
         loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
         assert loaded["global_step"] == 10
 
+    def test_double_failed_torch_load_raises_integrity_error(
+        self, tmp_path: Path, monkeypatch
+    ) -> None:
+        save_state(tmp_path, _mock_state())
+
+        calls = {"count": 0}
+        real_load = torch.load
+
+        def fake_load(*args: Any, **kwargs: Any) -> Any:
+            calls["count"] += 1
+            if calls["count"] == 1:
+                raise RuntimeError("weights-only failed")
+            raise RuntimeError("legacy failed")
+
+        monkeypatch.setattr(torch, "load", fake_load)
+        with pytest.raises(ResumeIntegrityError, match="legacy load also failed"):
+            load_state(tmp_path, runtime_versions={"torch": torch.__version__})
+        monkeypatch.setattr(torch, "load", real_load)
+
+    def test_missing_sidecar_version_defaults_rng_to_none(
+        self, tmp_path: Path, monkeypatch
+    ) -> None:
+        save_state(tmp_path, _mock_state())
+        real_load = torch.load
+
+        def fake_load(*args: Any, **kwargs: Any) -> dict[str, Any]:
+            return {
+                "optimizer_state_dict": {"lr": 1e-4},
+                "scheduler_state_dict": {"step": 5},
+                "scaler_state_dict": None,
+                "torch_rng_state": torch.get_rng_state(),
+                "cuda_rng_state": None,
+                "global_step": 10,
+                "epoch": 0.5,
+                "best_val_loss": 0.9,
+                "dlm_manifest_hash": None,
+                "base_model_revision": "a" * 40,
+                "pinned_versions": {"torch": torch.__version__},
+                "use_qlora": False,
+            }
+
+        monkeypatch.setattr(torch, "load", fake_load)
+        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
+        monkeypatch.setattr(torch, "load", real_load)
+
+        assert loaded["numpy_rng_state"] is None
+        assert loaded["python_random_state"] is None
+
 
 class TestCaptureRuntimeVersions:
     def test_torch_key_populated(self) -> None:
@@ -353,3 +408,15 @@
         reports that drove the run."""
         versions = capture_runtime_versions()
         assert "sway" in versions
+
+    def test_missing_import_returns_none(self, monkeypatch) -> None:
+        real_import = builtins.__import__
+
+        def fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
+            if name == "bitsandbytes":
+                raise ImportError("forced missing package")
+            return real_import(name, *args, **kwargs)
+
+        monkeypatch.setattr(builtins, "__import__", fake_import)
+        versions = capture_runtime_versions()
+        assert versions["bitsandbytes"] is None
tests/unit/train/test_tokenization.py (modified, 33 lines changed)
@@ -12,12 +12,14 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Any, cast
 
+import numpy as np
 import pytest
 from transformers import PreTrainedTokenizerBase
 
 from dlm.directives.cache import TokenizedCache
 from dlm.train.tokenization import (
     TokenizationStats,
+    _as_int_list,
     pretokenize_rows,
 )
 
@@ -232,3 +234,19 @@ class TestStatsDataclass:
         )
         with pytest.raises(dataclasses.FrozenInstanceError):
             s.total_sections = 3  # type: ignore[misc]
+
+
+class TestAsIntList:
+    def test_numpy_batch_of_one_is_flattened(self) -> None:
+        arr = np.asarray([[1, 2, 3]], dtype=np.int64)
+        assert _as_int_list(arr) == [1, 2, 3]
+
+    def test_tolist_like_object_is_flattened(self) -> None:
+        class _FakeTensor:
+            def tolist(self) -> list[list[int]]:
+                return [[4, 5, 6]]
+
+        assert _as_int_list(_FakeTensor()) == [4, 5, 6]
+
+    def test_plain_iterable_falls_back_to_iteration(self) -> None:
+        assert _as_int_list((7, 8, 9)) == [7, 8, 9]
tests/unit/train/test_trainer_helpers.pymodified
366 lines changed — click to load
@@ -8,14 +8,36 @@ testing directly.
88
 
99
 from __future__ import annotations
1010
 
11
+import logging
1112
 from pathlib import Path
12
-
13
+from types import SimpleNamespace
14
+from typing import cast
15
+from unittest.mock import MagicMock
16
+
17
+import pytest
18
+
19
+from dlm.base_models import BASE_MODELS
20
+from dlm.directives import ExpandResult, SourceProvenance
21
+from dlm.directives.discovery import DiscoveredConfig
22
+from dlm.directives.schema import DlmTrainingConfig
23
+from dlm.doc.parser import ParsedDlm
24
+from dlm.doc.schema import DlmFrontmatter, SourceDirective, TrainingConfig
25
+from dlm.doc.sections import Section, SectionType
26
+from dlm.lock import LockDecision, LockSchemaError, Severity
27
+from dlm.replay import ChangeSet
1328
 from dlm.train.trainer import (
29
+    _append_change_set_to_replay,
1430
     _append_training_run,
31
+    _attach_dlm_trainer_callback,
32
+    _build_candidate_lock,
33
+    _compute_weight_distribution,
34
+    _expand_directives,
1535
     _maybe_float,
36
+    _maybe_record_tokenization,
1637
     _next_run_id,
1738
     _sample_replay_rows,
1839
     _utc_naive,
40
+    _validate_or_abort_lock,
1941
 )
2042
 
2143
 # --- _maybe_float -----------------------------------------------------------
@@ -55,9 +77,10 @@ class TestUtcNaive:
5577
 # --- _sample_replay_rows ----------------------------------------------------
5678
 
5779
 
58
-class _FakeChangeSet:
59
-    def __init__(self, new_count: int) -> None:
60
-        self.new = [object() for _ in range(new_count)]
80
+def _fake_change_set(new_count: int) -> ChangeSet:
81
+    return ChangeSet(
82
+        new=[Section(type=SectionType.PROSE, content=f"row {i}") for i in range(new_count)]
83
+    )
6184
 
6285
 
6386
 class _EmptyReplay:
@@ -86,7 +109,7 @@ class TestSampleReplayRows:
86109
         replay = _EmptyReplay()
87110
         out = _sample_replay_rows(
88111
             replay,  # type: ignore[arg-type]
89
-            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
112
+            change_set=_fake_change_set(5),
90113
             seed=42,
91114
             adapter_version=1,
92115
         )
@@ -96,7 +119,7 @@ class TestSampleReplayRows:
96119
         replay = _WarmReplay(entries=200)
97120
         out = _sample_replay_rows(
98121
             replay,  # type: ignore[arg-type]
99
-            change_set=_FakeChangeSet(100),  # type: ignore[arg-type]
122
+            change_set=_fake_change_set(100),
100123
             seed=42,
101124
             adapter_version=1,
102125
         )
@@ -108,7 +131,7 @@ class TestSampleReplayRows:
108131
         replay = _WarmReplay(entries=100)
109132
         _sample_replay_rows(
110133
             replay,  # type: ignore[arg-type]
111
-            change_set=_FakeChangeSet(0),  # |new| = 0 → k = max(32, 0) = 32
134
+            change_set=_fake_change_set(0),  # |new| = 0 → k = max(32, 0) = 32
112135
             seed=0,
113136
             adapter_version=1,
114137
         )
@@ -123,13 +146,13 @@ class TestSampleReplayRows:
123146
         # both sample_rows calls receive an equal-state Random instance.
124147
         _sample_replay_rows(
125148
             replay1,  # type: ignore[arg-type]
126
-            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
149
+            change_set=_fake_change_set(5),
127150
             seed=7,
128151
             adapter_version=3,
129152
         )
130153
         _sample_replay_rows(
131154
             replay2,  # type: ignore[arg-type]
132
-            change_set=_FakeChangeSet(5),  # type: ignore[arg-type]
155
+            change_set=_fake_change_set(5),
133156
             seed=7,
134157
             adapter_version=3,
135158
         )
@@ -151,6 +174,34 @@ def _bootstrap_store(tmp_path: Path) -> object:
151174
     return store
152175
 
153176
 
177
+_SOURCE_PATH_SENTINEL = object()
178
+
179
+
180
+def _parsed(
181
+    tmp_path: Path,
182
+    *,
183
+    source_path: object = _SOURCE_PATH_SENTINEL,
184
+    sections: tuple[Section, ...] | None = None,
185
+    sources: tuple[SourceDirective, ...] | None = None,
186
+) -> ParsedDlm:
187
+    resolved_source_path: Path | None
188
+    if source_path is _SOURCE_PATH_SENTINEL:
189
+        resolved_source_path = tmp_path / "doc.dlm"
190
+        resolved_source_path.write_text("placeholder .dlm body\n", encoding="utf-8")
191
+    else:
192
+        assert source_path is None or isinstance(source_path, Path)
193
+        resolved_source_path = source_path
194
+    return ParsedDlm(
195
+        frontmatter=DlmFrontmatter(
196
+            dlm_id="01HZ4X7TGZM3J1A2B3C4D5E6F7",
197
+            base_model="smollm2-135m",
198
+            training=TrainingConfig(seed=42, sources=sources),
199
+        ),
200
+        sections=sections or (Section(type=SectionType.PROSE, content="x"),),
201
+        source_path=resolved_source_path,
202
+    )
203
+
204
+
154205
 class TestNextRunId:
155206
     def test_missing_manifest_returns_1(self, tmp_path: Path) -> None:
156207
         """Edge case: manifest not yet written → fresh run."""
@@ -308,3 +359,243 @@ class TestSnapshotTrainingState:
308359
         )
309360
         assert state["scaler_state_dict"] is None
310361
         assert state["use_qlora"] is True
362
+
363
+
364
+class TestAttachDlmTrainerCallback:
365
+    def test_returns_when_trainer_has_no_add_callback(self) -> None:
366
+        _attach_dlm_trainer_callback(
367
+            trainer=SimpleNamespace(),
368
+            recorder=MagicMock(),
369
+            run_id=1,
370
+            step_logger=MagicMock(),
371
+        )
372
+
373
+    def test_warns_and_swallows_callback_attachment_errors(
374
+        self,
375
+        caplog: pytest.LogCaptureFixture,
376
+    ) -> None:
377
+        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
378
+        trainer = SimpleNamespace(add_callback=MagicMock(side_effect=RuntimeError("boom")))
379
+
380
+        _attach_dlm_trainer_callback(
381
+            trainer=trainer,
382
+            recorder=MagicMock(),
383
+            run_id=1,
384
+            step_logger=MagicMock(),
385
+        )
386
+
387
+        assert "failed to attach DlmTrainerCallback" in caplog.text
388
+
389
+
390
+class TestMaybeRecordTokenization:
391
+    def test_missing_trainer_stats_is_a_no_op(self) -> None:
392
+        recorder = MagicMock()
393
+
394
+        _maybe_record_tokenization(
395
+            recorder=recorder,
396
+            run_id=1,
397
+            trainer=SimpleNamespace(),
398
+        )
399
+
400
+        recorder.record_tokenization.assert_not_called()
401
+
402
+
+class TestAppendChangeSetToReplay:
+    def test_all_media_change_set_does_not_append(self) -> None:
+        replay = MagicMock()
+        change_set = SimpleNamespace(
+            new=[
+                Section(type=SectionType.IMAGE, content="", media_path="hero.png"),
+                Section(
+                    type=SectionType.AUDIO,
+                    content="",
+                    media_path="clip.wav",
+                    media_transcript="spoken transcript",
+                ),
+            ]
+        )
+
+        _append_change_set_to_replay(
+            replay,
+            cast(ChangeSet, change_set),
+            run_id=7,
+        )
+
+        replay.append_many.assert_not_called()
+
+
+class TestBuildCandidateLock:
+    def test_requires_source_path(self, tmp_path: Path) -> None:
+        parsed = _parsed(tmp_path, source_path=None)
+
+        with pytest.raises(ValueError, match="source_path is required"):
+            _build_candidate_lock(
+                parsed=parsed,
+                spec=BASE_MODELS["smollm2-135m"],
+                seed=42,
+                run_id=1,
+                versions={"torch": "2.4.0"},
+                determinism_class="strict",
+                capabilities=None,
+            )
+
+
+class TestValidateOrAbortLock:
+    def test_default_mode_reraises_unreadable_prior_lock(self, tmp_path: Path) -> None:
+        store = _bootstrap_store(tmp_path)
+        parsed = _parsed(tmp_path)
+        (store.root / "dlm.lock").write_text("{not json", encoding="utf-8")  # type: ignore[attr-defined]
+
+        with pytest.raises(LockSchemaError):
+            _validate_or_abort_lock(
+                store=store,  # type: ignore[arg-type]
+                parsed=parsed,
+                spec=BASE_MODELS["smollm2-135m"],
+                seed=42,
+                run_id=1,
+                versions={"torch": "2.4.0"},
+                determinism_class="strict",
+                capabilities=None,
+                lock_mode="default",
+            )
+
+    def test_logs_warning_mismatches_when_validator_allows_proceed(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        import dlm.train.trainer as trainer_mod
+
+        store = _bootstrap_store(tmp_path)
+        parsed = _parsed(tmp_path)
+        decision = LockDecision(
+            action="proceed_with_warnings",
+            mismatches=[(Severity.WARN, "torch minor-version drift")],
+            should_write_lock=True,
+        )
+        monkeypatch.setattr(trainer_mod, "load_lock", lambda _root: object())
+        monkeypatch.setattr(
+            trainer_mod,
+            "validate_lock",
+            lambda _prior, _candidate, mode="default": decision,
+        )
+        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
+
+        got = _validate_or_abort_lock(
+            store=store,  # type: ignore[arg-type]
+            parsed=parsed,
+            spec=BASE_MODELS["smollm2-135m"],
+            seed=42,
+            run_id=1,
+            versions={"torch": "2.4.0"},
+            determinism_class="strict",
+            capabilities=None,
+            lock_mode="default",
+        )
+
+        assert got == decision
+        assert "dlm.lock drift: torch minor-version drift" in caplog.text
+
+
+class TestComputeWeightDistribution:
+    def test_counts_rows_when_directive_weights_are_active(self, tmp_path: Path) -> None:
+        parsed = _parsed(
+            tmp_path,
+            sections=(Section(type=SectionType.PROSE, content="note", tags={"kind": "note"}),),
+        )
+        discovered = (
+            DiscoveredConfig(
+                anchor=tmp_path,
+                config=DlmTrainingConfig(weights={"kind": {"note": 2.0}}),
+                ignore_rules=(),
+            ),
+        )
+
+        dist = _compute_weight_distribution(parsed=parsed, directive_discovered=discovered)
+
+        assert dist == {"kind": {"note": 1}}
+
+
+class TestExpandDirectives:
+    def test_returns_original_parsed_when_expansion_finds_no_sections(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        parsed = _parsed(
+            tmp_path,
+            sources=(SourceDirective(path="corpus"),),
+        )
+        discovered = (
+            DiscoveredConfig(
+                anchor=tmp_path,
+                config=DlmTrainingConfig(),
+                ignore_rules=(),
+            ),
+        )
+
+        def _fake_expand_sources(
+            parsed_arg: ParsedDlm,
+            *,
+            base_path: Path,
+        ) -> ExpandResult:
+            assert parsed_arg is parsed
+            assert parsed.source_path is not None
+            assert base_path == parsed.source_path.parent
+            return ExpandResult(
+                sections=(),
+                provenance=(SourceProvenance(path="corpus", file_count=0, total_bytes=0),),
+                discovered=discovered,
+            )
+
+        monkeypatch.setattr("dlm.directives.expand_sources", _fake_expand_sources)
+
+        new_parsed, provenance, got_discovered = _expand_directives(parsed)
+
+        assert new_parsed is parsed
+        assert provenance[0].file_count == 0
+        assert got_discovered == discovered
+
+    def test_falls_back_to_cwd_and_logs_when_sections_expand(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
+        parsed = _parsed(
+            tmp_path,
+            source_path=None,
+            sources=(SourceDirective(path="corpus"),),
+        )
+        captured: dict[str, Path] = {}
+
+        def _fake_expand_sources(
+            parsed_arg: ParsedDlm,
+            *,
+            base_path: Path,
+        ) -> ExpandResult:
+            captured["base_path"] = base_path
+            assert parsed_arg is parsed
+            return ExpandResult(
+                sections=(Section(type=SectionType.PROSE, content="expanded prose"),),
+                provenance=(SourceProvenance(path="corpus", file_count=1, total_bytes=14),),
+                discovered=(
+                    DiscoveredConfig(
+                        anchor=base_path,
+                        config=DlmTrainingConfig(),
+                        ignore_rules=(),
+                    ),
+                ),
+            )
+
+        monkeypatch.setattr("dlm.directives.expand_sources", _fake_expand_sources)
+        caplog.set_level(logging.INFO, logger="dlm.train.trainer")
+
+        new_parsed, provenance, discovered = _expand_directives(parsed)
+
+        assert captured["base_path"] == Path.cwd()
+        assert len(new_parsed.sections) == len(parsed.sections) + 1
+        assert provenance[0].path == "corpus"
+        assert len(discovered) == 1
+        assert "directives: expanded 1 file(s) across 1 source(s)" in caplog.text
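The attach-callback tests above pin down a warn-and-swallow contract: skip trainers without `add_callback`, and log (never raise) when attachment fails. A minimal sketch of that pattern, assuming a hypothetical `attach_callback_sketch` helper rather than the real `_attach_dlm_trainer_callback` in `dlm/train/trainer.py`:

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("dlm.train.trainer")


def attach_callback_sketch(trainer: object, callback: object) -> bool:
    """Attach a callback if the trainer supports it; warn instead of raising."""
    add = getattr(trainer, "add_callback", None)
    if add is None:
        # Trainer has no add_callback hook: silently skip, as the first test expects.
        return False
    try:
        add(callback)
    except Exception:
        # Never let a telemetry hook break training; log and move on.
        logger.warning("failed to attach DlmTrainerCallback", exc_info=True)
        return False
    return True
```

The shape mirrors what both tests assert: a bare `SimpleNamespace` is a no-op, and a `RuntimeError` from `add_callback` ends up in the log rather than propagating.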
tests/unit/watch/test_debounce.py (modified, 16 lines changed)
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from unittest.mock import patch
+
 import pytest
 
 from dlm.watch.debounce import Debouncer
@@ -77,3 +79,8 @@ class TestDebouncerValidation:
             Debouncer(quiet_seconds=0)
         with pytest.raises(ValueError, match="quiet_seconds"):
             Debouncer(quiet_seconds=-0.1)
+
+    def test_default_clock_uses_time_monotonic(self) -> None:
+        d = Debouncer(quiet_seconds=0.4)
+        with patch("time.monotonic", return_value=12.5):
+            assert d._now() == 12.5
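The new test confirms the default clock is `time.monotonic`, and the validation tests require a positive `quiet_seconds`. A debouncer of the shape these tests assume could be sketched like this (a sketch only; the real implementation lives in `dlm/watch/debounce.py` and its method names may differ):

```python
import time


class DebouncerSketch:
    """Reports quiet only after quiet_seconds have elapsed since the last event."""

    def __init__(self, quiet_seconds: float) -> None:
        if quiet_seconds <= 0:
            raise ValueError("quiet_seconds must be positive")
        self.quiet_seconds = quiet_seconds
        self._last_event: float | None = None

    def _now(self) -> float:
        # Monotonic clock: immune to wall-clock adjustments (NTP, DST),
        # which is why the test patches time.monotonic rather than time.time.
        return time.monotonic()

    def record_event(self) -> None:
        self._last_event = self._now()

    def is_quiet(self) -> bool:
        if self._last_event is None:
            return True
        return self._now() - self._last_event >= self.quiet_seconds
```

Routing all time reads through `_now()` is what makes the patched-clock test possible: the test can pin the clock without touching the debounce logic itself.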
tests/unit/watch/test_watcher_filter.py (modified, 34 lines changed)
@@ -45,3 +45,34 @@ class TestFilterEvents:
         target = tmp_path / "gone.dlm"
         batch: set[tuple[object, str]] = {("added", str(target))}
         assert filter_events_for_path(batch, target) is True
+
+    def test_target_resolve_oserror_falls_back_to_plain_string(
+        self, tmp_path: Path, monkeypatch
+    ) -> None:
+        target = tmp_path / "doc.dlm"
+        target.write_text("x")
+        batch: set[tuple[object, str]] = {("modified", str(target))}
+
+        monkeypatch.setattr(Path, "resolve", lambda self: (_ for _ in ()).throw(OSError("boom")))
+
+        assert filter_events_for_path(batch, target) is True
+
+    def test_raw_path_resolve_oserror_falls_back_to_raw_path_string(
+        self,
+        tmp_path: Path,
+        monkeypatch,
+    ) -> None:
+        target = tmp_path / "doc.dlm"
+        target.write_text("x")
+        batch: set[tuple[object, str]] = {("modified", str(target))}
+        real_path = Path
+
+        class BrokenPath(Path):
+            _flavour = real_path()._flavour  # type: ignore[attr-defined]
+
+            def resolve(self) -> Path:
+                raise OSError("boom")
+
+        monkeypatch.setattr("dlm.watch.watcher.Path", BrokenPath)
+
+        assert filter_events_for_path(batch, target) is True
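Both new tests force `Path.resolve` to raise `OSError`, once for the watched target and once for the raw event path. The comparison they imply resolves both sides when possible and falls back to plain strings otherwise; a sketch under those assumptions (the hypothetical `filter_events_sketch` stands in for the real `filter_events_for_path`):

```python
from pathlib import Path


def filter_events_sketch(batch: set[tuple[object, str]], target: Path) -> bool:
    """Return True if any event in the batch refers to the target path."""
    try:
        target_key = str(target.resolve())
    except OSError:
        # Resolution can fail (permissions, dangling symlinks): compare as-is.
        target_key = str(target)
    for _change, raw_path in batch:
        try:
            raw_key = str(Path(raw_path).resolve())
        except OSError:
            raw_key = raw_path  # fall back to the raw event path string
        if raw_key == target_key:
            return True
    return False
```

The point of the fallback is that a transient resolution failure degrades to a cheaper string comparison instead of dropping the event, which is why both tests still expect `True`.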
tests/unit/watch/test_watcher_loop.py (added, 55 lines changed)
@@ -0,0 +1,55 @@
+"""Loop-level coverage for watch_for_changes and default stream wrapper."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+
+from dlm.watch.errors import WatchSetupError
+from dlm.watch.watcher import _default_event_stream, watch_for_changes
+
+
+def test_default_event_stream_wraps_watchfiles_iterator(tmp_path: Path) -> None:
+    seen: list[tuple[str, object | None]] = []
+    expected = [{("modified", str(tmp_path / "doc.dlm"))}]
+
+    def fake_watch(
+        path: str, *, stop_event: object | None = None
+    ) -> Iterator[set[tuple[object, str]]]:
+        seen.append((path, stop_event))
+        yield from expected
+
+    with patch.dict("sys.modules", {"watchfiles": SimpleNamespace(watch=fake_watch)}):
+        batches = list(_default_event_stream(tmp_path / "doc.dlm", stop_event="stop"))
+
+    assert batches == expected
+    assert seen == [(str(tmp_path), "stop")]
+
+
+def test_watch_for_changes_requires_existing_file(tmp_path: Path) -> None:
+    with pytest.raises(WatchSetupError, match="does not exist"):
+        watch_for_changes(tmp_path / "missing.dlm", lambda: None)
+
+
+def test_watch_for_changes_invokes_callback_for_matching_batches(tmp_path: Path) -> None:
+    target = tmp_path / "doc.dlm"
+    target.write_text("x")
+    seen: list[str] = []
+
+    def event_stream(
+        _path: Path, *, stop_event: object | None = None
+    ) -> Iterator[set[tuple[object, str]]]:
+        assert stop_event == "stop"
+        yield {("modified", str(tmp_path / "other.dlm"))}
+        yield {("modified", str(target))}
+        yield {("added", str(target))}
+
+    watch_for_changes(
+        target, lambda: seen.append("changed"), stop_event="stop", event_stream=event_stream
+    )
+
+    assert seen == ["changed", "changed"]
Diff truncated: 121 files.