tenseleyflow/documentlanguagemodel / c078408

Browse files

Warn on degraded eval and lock fallbacks

Authored by espadonne
SHA
c078408783de7e56948078b619ead865aa21ef6e
Parents
4acddd6
Tree
b76de7a

8 changed files

Status | File | + | -
M src/dlm/eval/mode_split.py 20 5
M src/dlm/eval/probes.py 48 22
M src/dlm/export/runner.py 9 7
M src/dlm/train/trainer.py 7 1
M tests/unit/eval/test_mode_split.py 29 8
M tests/unit/eval/test_probes.py 12 0
M tests/unit/export/test_runner.py 25 0
M tests/unit/train/test_lock_wiring.py 27 0
src/dlm/eval/mode_split.py — modified
@@ -14,8 +14,11 @@ tests drive the grouping logic with a mock trainer.
1414
 
1515
 from __future__ import annotations
1616
 
17
+import logging
1718
 from typing import Any
1819
 
20
+_LOG = logging.getLogger(__name__)
21
+
1922
 
2023
 def compute_val_loss_by_mode(trainer: Any, val_ds: Any) -> tuple[float | None, float | None]:
2124
     """Return `(val_loss_cpt, val_loss_sft)` from a post-train eval pass.
@@ -53,22 +56,34 @@ def compute_val_loss_by_mode(trainer: Any, val_ds: Any) -> tuple[float | None, f
5356
         elif mode == "sft":
5457
             sft_idx.append(i)
5558
 
56
-    cpt_loss = _safe_eval_loss(trainer, val_ds, cpt_idx)
57
-    sft_loss = _safe_eval_loss(trainer, val_ds, sft_idx)
59
+    cpt_loss = _safe_eval_loss(trainer, val_ds, cpt_idx, mode="cpt")
60
+    sft_loss = _safe_eval_loss(trainer, val_ds, sft_idx, mode="sft")
5861
     return (cpt_loss, sft_loss)
5962
 
6063
 
61
-def _safe_eval_loss(trainer: Any, val_ds: Any, indices: list[int]) -> float | None:
64
+def _safe_eval_loss(trainer: Any, val_ds: Any, indices: list[int], *, mode: str) -> float | None:
6265
     """Run `trainer.evaluate(eval_dataset=subset)`; return eval_loss or None."""
6366
     if not indices:
6467
         return None
6568
     try:
6669
         subset = val_ds.select(indices)
67
-    except Exception:
70
+    except (AttributeError, IndexError, TypeError, ValueError) as exc:
71
+        _LOG.warning(
72
+            "val-loss split skipped %s subset selection (%d rows): %s",
73
+            mode,
74
+            len(indices),
75
+            exc,
76
+        )
6877
         return None
6978
     try:
7079
         metrics = trainer.evaluate(eval_dataset=subset)
71
-    except Exception:
80
+    except (RuntimeError, TypeError, ValueError) as exc:
81
+        _LOG.warning(
82
+            "val-loss split skipped %s evaluation (%d rows): %s",
83
+            mode,
84
+            len(indices),
85
+            exc,
86
+        )
7287
         return None
7388
     loss = metrics.get("eval_loss") if isinstance(metrics, dict) else None
7489
     if loss is None:
src/dlm/eval/probes.py — modified
@@ -21,13 +21,16 @@ the output diff between runs is meaningful.
2121
 from __future__ import annotations
2222
 
2323
 import hashlib
24
+import logging
2425
 from dataclasses import dataclass
2526
 
26
-from dlm.data.instruction_parser import parse_instruction_body
27
+from dlm.data.errors import InstructionParseError
28
+from dlm.data.instruction_parser import QAPair, parse_instruction_body
2729
 from dlm.doc.sections import Section, SectionType
2830
 
2931
 _PROBE_MARKER = "!probe"
3032
 _PROBE_HEADER = f"### Q {_PROBE_MARKER}"
33
+_LOG = logging.getLogger(__name__)
3134
 
3235
 
3336
 @dataclass(frozen=True)
@@ -47,20 +50,31 @@ def extract_probes(sections: list[Section], *, k: int = 3, seed: int = 0) -> lis
4750
     is filled from INSTRUCTION section Q/A pairs via a deterministic
4851
     sample.
4952
     """
50
-    explicit = list(_extract_explicit_probes(sections))
53
+    parsed_pairs = _parse_instruction_sections(sections)
54
+    explicit = list(_extract_explicit_probes(sections, parsed_pairs=parsed_pairs))
5155
     if len(explicit) >= k:
5256
         return explicit[:k]
5357
 
5458
     needed = k - len(explicit)
5559
     seen_prompts = {p.prompt for p in explicit}
56
-    auto = _auto_sample_probes(sections, k=needed, seed=seed, exclude=seen_prompts)
60
+    auto = _auto_sample_probes(
61
+        sections,
62
+        k=needed,
63
+        seed=seed,
64
+        exclude=seen_prompts,
65
+        parsed_pairs=parsed_pairs,
66
+    )
5767
     return [*explicit, *auto]
5868
 
5969
 
6070
 # --- internals ---------------------------------------------------------------
6171
 
6272
 
63
-def _extract_explicit_probes(sections: list[Section]) -> list[Probe]:
73
+def _extract_explicit_probes(
74
+    sections: list[Section],
75
+    *,
76
+    parsed_pairs: dict[str, list[QAPair]],
77
+) -> list[Probe]:
6478
     """Find INSTRUCTION Q/A pairs whose question starts with `!probe`.
6579
 
6680
     The `!probe` marker appears on the Q header line; the Q body is the
@@ -72,16 +86,7 @@ def _extract_explicit_probes(sections: list[Section]) -> list[Probe]:
7286
     for section in sections:
7387
         if section.type is not SectionType.INSTRUCTION:
7488
             continue
75
-        try:
76
-            pairs = parse_instruction_body(
77
-                _normalize_probe_markers(section.content),
78
-                section_id=section.section_id,
79
-            )
80
-        except Exception:
81
-            # Malformed instruction bodies are the instruction-parser's
82
-            # problem; probe extraction is best-effort and must not hide
83
-            # grammar errors by raising here.
84
-            continue
89
+        pairs = parsed_pairs.get(section.section_id, [])
8590
         for pair in pairs:
8691
             # After normalization every probe pair sits in a private
8792
             # namespace; we flag them via a sentinel prefix in the body.
@@ -127,7 +132,12 @@ def _normalize_probe_markers(body: str) -> str:
127132
 
128133
 
129134
 def _auto_sample_probes(
130
-    sections: list[Section], *, k: int, seed: int, exclude: set[str]
135
+    sections: list[Section],
136
+    *,
137
+    k: int,
138
+    seed: int,
139
+    exclude: set[str],
140
+    parsed_pairs: dict[str, list[QAPair]],
131141
 ) -> list[Probe]:
132142
     """Deterministically pick `k` questions from INSTRUCTION sections.
133143
 
@@ -148,13 +158,7 @@ def _auto_sample_probes(
148158
     for section in sections:
149159
         if section.type is not SectionType.INSTRUCTION:
150160
             continue
151
-        try:
152
-            pairs = parse_instruction_body(
153
-                _normalize_probe_markers(section.content),
154
-                section_id=section.section_id,
155
-            )
156
-        except Exception:
157
-            continue
161
+        pairs = parsed_pairs.get(section.section_id, [])
158162
         for pair in pairs:
159163
             # Skip explicit probes (their question body was prefixed
160164
             # with `!probe:` by the normalizer) — the caller handles
@@ -182,3 +186,25 @@ def _auto_sample_probes(
182186
 def _probe_sort_key(prompt: str, seed: int) -> str:
183187
     h = hashlib.sha256(f"{seed}\x00{prompt}".encode())
184188
     return h.hexdigest()
189
+
190
+
191
+def _parse_instruction_sections(sections: list[Section]) -> dict[str, list[QAPair]]:
192
+    """Parse instruction sections once so malformed blocks warn once."""
193
+    parsed: dict[str, list[QAPair]] = {}
194
+    for section in sections:
195
+        if section.type is not SectionType.INSTRUCTION:
196
+            continue
197
+        try:
198
+            parsed[section.section_id] = parse_instruction_body(
199
+                _normalize_probe_markers(section.content),
200
+                section_id=section.section_id,
201
+            )
202
+        except InstructionParseError as exc:
203
+            _LOG.warning(
204
+                "probe extraction skipped malformed instruction section %s at line %d: %s",
205
+                exc.section_id,
206
+                exc.section_line,
207
+                exc,
208
+            )
209
+            parsed[section.section_id] = []
210
+    return parsed
src/dlm/export/runner.py — modified
@@ -23,6 +23,7 @@ from pathlib import Path
2323
 from typing import TYPE_CHECKING, Any
2424
 
2525
 from dlm.export import adapter_gguf, base_gguf, merge, preflight
26
+from dlm.export.errors import ExportManifestError
2627
 from dlm.export.manifest import (
2728
     EXPORT_MANIFEST_FILENAME,
2829
     ExportManifest,
@@ -483,14 +484,15 @@ def _cached_base_matches(export_dir: Path, base_gguf_path: Path, quant: str) ->
483484
         from dlm.export.manifest import compute_sha256, load_export_manifest
484485
 
485486
         prior = load_export_manifest(export_dir)
486
-    except Exception:
487
+        if prior.quant != quant:
488
+            return False
489
+        recorded = next((a for a in prior.artifacts if a.path == base_gguf_path.name), None)
490
+        if recorded is None:
491
+            return False
492
+        return compute_sha256(base_gguf_path) == recorded.sha256
493
+    except (ExportManifestError, OSError) as exc:
494
+        _LOG.warning("export cache ignored stale manifest under %s: %s", export_dir, exc)
487495
         return False
488
-    if prior.quant != quant:
489
-        return False
490
-    recorded = next((a for a in prior.artifacts if a.path == base_gguf_path.name), None)
491
-    if recorded is None:
492
-        return False
493
-    return compute_sha256(base_gguf_path) == recorded.sha256
494496
 
495497
 
496498
 def _perform_merge_path(  # pragma: no cover
src/dlm/train/trainer.py — modified
@@ -42,6 +42,7 @@ from dlm.lock import (
4242
     DlmLock,
4343
     LockDecision,
4444
     LockMode,
45
+    LockSchemaError,
4546
     LockValidationError,
4647
     build_lock,
4748
     hardware_tier_from_backend,
@@ -1565,7 +1566,7 @@ def _validate_or_abort_lock(
15651566
     )
15661567
     try:
15671568
         prior = load_lock(store.root)
1568
-    except Exception:
1569
+    except LockSchemaError as exc:
15691570
         # Audit-05 N5: a corrupt `dlm.lock` on disk would normally kill
15701571
         # the run at load time. Under `--update-lock` the operator has
15711572
         # explicitly opted to overwrite the file; treat the parse
@@ -1575,6 +1576,11 @@ def _validate_or_abort_lock(
15751576
         # "don't touch the file").
15761577
         if lock_mode != "update":
15771578
             raise
1579
+        _LOG.warning(
1580
+            "update-lock: ignoring unreadable prior dlm.lock at %s: %s",
1581
+            store.root,
1582
+            exc,
1583
+        )
15781584
         prior = None
15791585
     decision = validate_lock(prior, candidate, mode=lock_mode)
15801586
 
tests/unit/eval/test_mode_split.py — modified
@@ -2,10 +2,12 @@
22
 
33
 from __future__ import annotations
44
 
5
-from types import SimpleNamespace
5
+import logging
66
 from typing import Any
77
 from unittest.mock import MagicMock
88
 
9
+import pytest
10
+
911
 from dlm.eval.mode_split import compute_val_loss_by_mode
1012
 
1113
 
@@ -123,15 +125,21 @@ class TestModeClassification:
123125
 
124126
 
125127
 class TestEvalFailures:
126
-    def test_evaluate_exception_yields_none(self) -> None:
128
+    def test_evaluate_exception_yields_none(
129
+        self,
130
+        caplog: pytest.LogCaptureFixture,
131
+    ) -> None:
127132
         """A stack-version skew that makes evaluate() raise shouldn't
128133
         crash training — the affected mode just stays None."""
134
+        caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
129135
         trainer = MagicMock()
130136
         trainer.evaluate.side_effect = RuntimeError("TRL drift")
131137
         val = _FakeDataset([{"text": "a"}, {"messages": []}])
132138
         cpt, sft = compute_val_loss_by_mode(trainer, val)
133139
         assert cpt is None
134140
         assert sft is None
141
+        assert "val-loss split skipped cpt evaluation" in caplog.text
142
+        assert "val-loss split skipped sft evaluation" in caplog.text
135143
 
136144
     def test_missing_eval_loss_key_yields_none(self) -> None:
137145
         trainer = MagicMock()
@@ -141,15 +149,28 @@ class TestEvalFailures:
141149
         assert cpt is None
142150
         assert sft is None
143151
 
144
-    def test_select_failure_yields_none(self) -> None:
152
+    def test_select_failure_yields_none(
153
+        self,
154
+        caplog: pytest.LogCaptureFixture,
155
+    ) -> None:
156
+        caplog.set_level(logging.WARNING, logger="dlm.eval.mode_split")
145157
         trainer = MagicMock()
146158
         trainer.evaluate.return_value = {"eval_loss": 0.0}
147
-        # A dataset without a .select method — the helper should swallow.
148
-        bad_val = SimpleNamespace(
149
-            __len__=lambda: 1,
150
-            __iter__=lambda: iter([{"text": "a"}]),
151
-        )
159
+        # Dataset iteration works, but subset selection does not.
160
+        bad_val = _NoSelectDataset([{"text": "a"}])
152161
         cpt, sft = compute_val_loss_by_mode(trainer, bad_val)
153162
         # Both None — the helper couldn't build subsets.
154163
         assert cpt is None
155164
         assert sft is None
165
+        assert "val-loss split skipped cpt subset selection" in caplog.text
166
+
167
+
168
+class _NoSelectDataset:
169
+    def __init__(self, rows: list[dict[str, Any]]) -> None:
170
+        self._rows = rows
171
+
172
+    def __len__(self) -> int:
173
+        return len(self._rows)
174
+
175
+    def __iter__(self):  # type: ignore[no-untyped-def]
176
+        return iter(self._rows)
tests/unit/eval/test_probes.py — modified
@@ -3,6 +3,7 @@
33
 from __future__ import annotations
44
 
55
 import dataclasses
6
+import logging
67
 
78
 import pytest
89
 
@@ -73,6 +74,17 @@ class TestAutoSample:
7374
         s = Section(type=SectionType.INSTRUCTION, content=body)
7475
         assert extract_probes([s], k=0) == []
7576
 
77
+    def test_malformed_instruction_logs_warning_once(
78
+        self,
79
+        caplog: pytest.LogCaptureFixture,
80
+    ) -> None:
81
+        body = "### Q\nunterminated question"
82
+        s = Section(type=SectionType.INSTRUCTION, content=body)
83
+        caplog.set_level(logging.WARNING, logger="dlm.eval.probes")
84
+        assert extract_probes([s], k=3) == []
85
+        assert "probe extraction skipped malformed instruction section" in caplog.text
86
+        assert len(caplog.records) == 1
87
+
7688
 
7789
 class TestProbeDataclass:
7890
     def test_probe_is_frozen(self) -> None:
tests/unit/export/test_runner.py — modified
@@ -3,6 +3,7 @@
33
 from __future__ import annotations
44
 
55
 import json
6
+import logging
67
 from pathlib import Path
78
 from typing import Any
89
 
@@ -184,6 +185,30 @@ class TestCaching:
184185
         assert len(recorder2.commands) == 1
185186
         assert any("convert_lora_to_gguf.py" in str(a) for a in recorder2.commands[0])
186187
 
188
+    def test_bad_cached_manifest_logs_warning_and_rebuilds(
189
+        self,
190
+        tmp_path: Path,
191
+        monkeypatch: pytest.MonkeyPatch,
192
+        caplog: pytest.LogCaptureFixture,
193
+    ) -> None:
194
+        from dlm.export.errors import ExportManifestError
195
+        from dlm.export.runner import _cached_base_matches
196
+
197
+        export_dir = tmp_path / "exports" / "Q4_K_M"
198
+        export_dir.mkdir(parents=True)
199
+        base_gguf = export_dir / "base.Q4_K_M.gguf"
200
+        base_gguf.write_bytes(b"cached bytes")
201
+        (export_dir / "export_manifest.json").write_text("{}", encoding="utf-8")
202
+
203
+        def _raise(_export_dir: Path) -> object:
204
+            raise ExportManifestError("bad manifest")
205
+
206
+        monkeypatch.setattr("dlm.export.manifest.load_export_manifest", _raise)
207
+        caplog.set_level(logging.WARNING, logger="dlm.export.runner")
208
+
209
+        assert _cached_base_matches(export_dir, base_gguf, "Q4_K_M") is False
210
+        assert "export cache ignored stale manifest" in caplog.text
211
+
187212
 
188213
 class TestMergeGate:
189214
     def test_qlora_merge_without_dequantize_raises(self, tmp_path: Path) -> None:
tests/unit/train/test_lock_wiring.py — modified
@@ -2,6 +2,7 @@
22
 
33
 from __future__ import annotations
44
 
5
+import logging
56
 from pathlib import Path
67
 from types import SimpleNamespace
78
 from typing import Any
@@ -180,3 +181,29 @@ class TestUpdateModeOverrides:
180181
         updated = load_lock(store.root)
181182
         assert updated is not None
182183
         assert updated.base_model_revision == spec.revision
184
+
185
+    def test_update_mode_warns_and_recovers_from_broken_lock(
186
+        self,
187
+        tmp_path: Path,
188
+        caplog: pytest.LogCaptureFixture,
189
+    ) -> None:
190
+        store = _bootstrap_store(tmp_path)
191
+        parsed = _parsed(tmp_path)
192
+        spec = BASE_MODELS["smollm2-135m"]
193
+        store_lock = store.root / "dlm.lock"
194
+        store_lock.write_text("{not json", encoding="utf-8")
195
+
196
+        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
197
+        run(
198
+            store,
199
+            parsed,
200
+            spec,
201
+            _plan(),
202
+            trainer_factory=_mock_trainer_factory,
203
+            lock_mode="update",
204
+        )
205
+
206
+        updated = load_lock(store.root)
207
+        assert updated is not None
208
+        assert updated.base_model_revision == spec.revision
209
+        assert "update-lock: ignoring unreadable prior dlm.lock" in caplog.text