tenseleyflow/documentlanguagemodel / 31d74c7

Browse files

Prove preference mine bootstrap cycle

Authored by espadonne
SHA
31d74c795dbe7d05cfcff7c5cac70fab58764868
Parents
59da9a8
Tree
4c4c5d1

4 changed files

Status File + -
M src/dlm/metrics/recorder.py 17 0
M tests/fixtures/trained_store.py 14 4
A tests/integration/preference/test_mine_cycle.py 252 0
M tests/unit/metrics/test_recorder.py 17 1
src/dlm/metrics/recorder.py (modified)
@@ -263,6 +263,23 @@ class DlmTrainerCallback: # pragma: no cover - heavy trainer hook
263263
         self._run_id = run_id
264264
         self._step_logger = step_logger
265265
 
266
+    def __getattr__(self, name: str) -> Callable[..., None]:
267
+        """Provide no-op HF callback hooks we don't care about explicitly.
268
+
269
+        Newer `transformers` releases call more lifecycle methods via
270
+        `getattr(callback, event)` and expect every callback object to
271
+        expose the requested `on_*` attribute. We only need `on_log`
272
+        and `on_evaluate` for DLM metrics, so any other hook becomes a
273
+        harmless no-op instead of crashing the training loop.
274
+        """
275
+        if not name.startswith("on_"):
276
+            raise AttributeError(name)
277
+
278
+        def _noop(*_args: Any, **_kwargs: Any) -> None:
279
+            return None
280
+
281
+        return _noop
282
+
266283
     def on_log(
267284
         self,
268285
         _args: Any,
tests/fixtures/trained_store.py (modified)
@@ -27,6 +27,8 @@ from typing import TYPE_CHECKING, Final
2727
 import pytest
2828
 
2929
 if TYPE_CHECKING:
30
+    from dlm.hardware.capabilities import Capabilities
31
+    from dlm.hardware.plan import TrainingPlan
3032
     from dlm.store.paths import StorePath
3133
 
3234
 # Small enough to keep wall-clock manageable on CPU (smollm2-135m: ~1s/step
@@ -49,6 +51,8 @@ class TrainedStoreHandle:
4951
     home: Path
5052
     dlm_id: str
5153
     store: StorePath
54
+    plan: TrainingPlan
55
+    capabilities: Capabilities
5256
 
5357
 
5458
 @pytest.fixture(scope="session")
@@ -87,10 +91,6 @@ def trained_store(tmp_path_factory: pytest.TempPathFactory) -> Iterator[TrainedS
8791
         from dlm.train import run as run_training
8892
         from tests.fixtures.dlm_factory import make_dlm
8993
 
90
-        plan = doctor().plan
91
-        if plan is None:
92
-            pytest.skip("doctor() returned no viable training plan on this host")
93
-
9494
         home = tmp_path_factory.mktemp("dlm-trained-home")
9595
         os.environ["DLM_HOME"] = str(home)
9696
 
@@ -99,6 +99,14 @@ def trained_store(tmp_path_factory: pytest.TempPathFactory) -> Iterator[TrainedS
9999
 
100100
         parsed = parse_file(doc)
101101
         spec = resolve_base_model(parsed.frontmatter.base_model)
102
+        doctor_result = doctor(
103
+            training_config=parsed.frontmatter.training,
104
+            base_params=spec.params,
105
+            seq_len=min(parsed.frontmatter.training.sequence_len, spec.effective_context_length),
106
+        )
107
+        plan = doctor_result.plan
108
+        if plan is None:
109
+            pytest.skip("doctor() returned no viable training plan on this host")
102110
         store = for_dlm(parsed.frontmatter.dlm_id)
103111
         store.ensure_layout()
104112
 
@@ -130,6 +138,8 @@ def trained_store(tmp_path_factory: pytest.TempPathFactory) -> Iterator[TrainedS
130138
             home=home,
131139
             dlm_id=parsed.frontmatter.dlm_id,
132140
             store=store,
141
+            plan=plan,
142
+            capabilities=doctor_result.capabilities,
133143
         )
134144
     finally:
135145
         for key, value in saved_env.items():
tests/integration/preference/test_mine_cycle.py (added)
@@ -0,0 +1,252 @@
1
+"""Slow integration: train → mine → train again improves held-out preference score.
2
+
3
+This is Sprint 42's bootstrap-loop proof. We keep the mined candidates
4
+deterministic (scripted backend + judge) so the test is stable, but the
5
+two preference-training passes and the final held-out SwayJudge check are
6
+real.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
+import shutil
12
+from collections import deque
13
+from pathlib import Path
14
+from typing import TYPE_CHECKING
15
+
16
+import pytest
17
+from typer.testing import CliRunner
18
+
19
+from dlm.cli.app import app
20
+from dlm.doc.parser import ParsedDlm, parse_file
21
+from dlm.doc.serializer import serialize
22
+from dlm.preference.judge import JudgeInvocationError, JudgeUnavailableError, PairScore, SwayJudge
23
+
24
+if TYPE_CHECKING:
25
+    from tests.fixtures.trained_store import TrainedStoreHandle
26
+
27
+pytestmark = pytest.mark.slow
28
+
29
+_EXTRA_BODY = """
30
+::instruction::
31
+### Q
32
+What color is grass?
33
+### A
34
+Green.
35
+
36
+::instruction::
37
+### Q
38
+What is 10 - 3?
39
+### A
40
+7.
41
+
42
+::preference::
43
+### Prompt
44
+Is water wet?
45
+### Chosen
46
+Yes.
47
+### Rejected
48
+Water is generally considered wet in everyday language.
49
+"""
50
+
51
+_MINE_RESPONSES = {
52
+    "What is 2 + 2?": ["4.", "The sum of two and two is four."],
53
+    "What is the capital of France?": [
54
+        "Paris.",
55
+        "The capital of France is Paris.",
56
+    ],
57
+    "What color is grass?": ["Green.", "Grass is usually green."],
58
+    "What is 10 - 3?": ["7.", "Ten minus three equals seven."],
59
+}
60
+
61
+_HELD_OUT_PAIRS = (
62
+    ("What is 8 + 1?", "9.", "The result of adding eight and one is nine."),
63
+    ("What color is snow?", "White.", "Snow is typically white in daylight."),
64
+    ("What is the capital of Italy?", "Rome.", "The capital city of Italy is Rome."),
65
+)
66
+
67
+
68
+class _FakeMiningBackend:
69
+    def __init__(self, responses: dict[str, list[str]]) -> None:
70
+        self._responses = {prompt: deque(items) for prompt, items in responses.items()}
71
+
72
+    def load(self, spec: object, store: object, *, adapter_name: str | None = None) -> None:
73
+        _ = spec, store, adapter_name
74
+
75
+    def generate(self, prompt: str, **_kwargs: object) -> str:
76
+        return self._responses[prompt].popleft()
77
+
78
+    def unload(self) -> None:
79
+        return None
80
+
81
+
82
+class _TerseJudge:
83
+    name = "cli:terse-judge"
84
+    suggested_threshold = 0.1
85
+
86
+    def score_pair(self, prompt: str, candidate_a: str, candidate_b: str) -> PairScore:
87
+        _ = prompt
88
+        return PairScore(score_a=-float(len(candidate_a)), score_b=-float(len(candidate_b)))
89
+
90
+
91
+def _copy_fixture_store(
92
+    trained_store: TrainedStoreHandle,
93
+    *,
94
+    tmp_path: Path,
95
+    monkeypatch: pytest.MonkeyPatch,
96
+) -> tuple[Path, object]:
97
+    from dlm.store.manifest import load_manifest, save_manifest
98
+    from dlm.store.paths import for_dlm
99
+
100
+    home = tmp_path / "home"
101
+    home.mkdir()
102
+    monkeypatch.setenv("DLM_HOME", str(home))
103
+
104
+    source_doc = trained_store.doc
105
+    doc = home / source_doc.name
106
+    shutil.copy2(source_doc, doc)
107
+
108
+    parsed = parse_file(doc)
109
+    store = for_dlm(parsed.frontmatter.dlm_id)
110
+    shutil.copytree(trained_store.store.root, store.root, dirs_exist_ok=True)
111
+
112
+    manifest = load_manifest(store.manifest)
113
+    save_manifest(
114
+        store.manifest,
115
+        manifest.model_copy(update={"source_path": doc.resolve()}),
116
+    )
117
+    return doc, store
118
+
119
+
120
+def _prepare_doc_for_cycle(doc: Path) -> None:
121
+    current = doc.read_text(encoding="utf-8")
122
+    doc.write_text(current.rstrip() + "\n\n" + _EXTRA_BODY.lstrip(), encoding="utf-8")
123
+
124
+    parsed = parse_file(doc)
125
+    new_pref = parsed.frontmatter.training.preference.model_copy(
126
+        update={"method": "orpo", "enabled": True}
127
+    )
128
+    new_training = parsed.frontmatter.training.model_copy(update={"preference": new_pref})
129
+    rewritten = ParsedDlm(
130
+        frontmatter=parsed.frontmatter.model_copy(update={"training": new_training}),
131
+        sections=parsed.sections,
132
+    )
133
+    doc.write_text(serialize(rewritten), encoding="utf-8")
134
+
135
+
136
+def _patch_mining(monkeypatch: pytest.MonkeyPatch) -> None:
137
+    monkeypatch.setattr(
138
+        "dlm.inference.backends.select_backend",
139
+        lambda *args, **kwargs: "pytorch",
140
+    )
141
+    monkeypatch.setattr(
142
+        "dlm.inference.backends.build_backend",
143
+        lambda *args, **kwargs: _FakeMiningBackend(_MINE_RESPONSES),
144
+    )
145
+    monkeypatch.setattr(
146
+        "dlm.preference.build_judge",
147
+        lambda *args, **kwargs: _TerseJudge(),
148
+    )
149
+
150
+
151
+def _mean_margin_for_version(doc: Path, store: object, version: int) -> float:
152
+    target = store.adapter_version(version)
153
+    original = store.resolve_current_adapter()
154
+    assert original is not None
155
+    store.set_current_adapter(target)
156
+    try:
157
+        judge = SwayJudge(doc)
158
+        margins = [
159
+            judge.score_pair(prompt, chosen, rejected).margin
160
+            for prompt, chosen, rejected in _HELD_OUT_PAIRS
161
+        ]
162
+    except (JudgeUnavailableError, JudgeInvocationError) as exc:
163
+        pytest.skip(f"sway judge unavailable for mine-cycle proof: {exc}")
164
+    finally:
165
+        store.set_current_adapter(original)
166
+    return sum(margins) / len(margins)
167
+
168
+
169
+@pytest.mark.slow
170
+def test_preference_mine_cycle_improves_held_out_sway_margin(
171
+    trained_store: TrainedStoreHandle,
172
+    tmp_path: Path,
173
+    monkeypatch: pytest.MonkeyPatch,
174
+) -> None:
175
+    from dlm.base_models import resolve as resolve_base_model
176
+    from dlm.doc.sections import SectionType
177
+    from dlm.store.manifest import load_manifest
178
+    from dlm.train.preference.phase_orchestrator import run_phases
179
+
180
+    doc, store = _copy_fixture_store(trained_store, tmp_path=tmp_path, monkeypatch=monkeypatch)
181
+    _prepare_doc_for_cycle(doc)
182
+
183
+    parsed = parse_file(doc)
184
+    spec = resolve_base_model(parsed.frontmatter.base_model, accept_license=True)
185
+    plan = trained_store.plan
186
+    capabilities = trained_store.capabilities
187
+
188
+    baseline = run_phases(
189
+        store,
190
+        parsed,
191
+        spec,
192
+        plan,
193
+        phase="preference",
194
+        capabilities=capabilities,
195
+        lock_mode="ignore",
196
+        seed=42,
197
+        max_steps=20,
198
+    )
199
+    assert [result.phase for result in baseline] == ["preference"]
200
+    assert baseline[0].result.adapter_version == 2
201
+
202
+    _patch_mining(monkeypatch)
203
+    runner = CliRunner()
204
+    mine_result = runner.invoke(
205
+        app,
206
+        [
207
+            "--home",
208
+            str(tmp_path / "home"),
209
+            "preference",
210
+            "mine",
211
+            str(doc),
212
+            "--samples",
213
+            "2",
214
+            "--max-pairs",
215
+            "4",
216
+            "--apply",
217
+        ],
218
+    )
219
+    assert mine_result.exit_code == 0, mine_result.output
220
+
221
+    mined_doc = parse_file(doc)
222
+    auto_mined_sections = [
223
+        section
224
+        for section in mined_doc.sections
225
+        if section.type is SectionType.PREFERENCE and section.auto_mined
226
+    ]
227
+    assert len(auto_mined_sections) == 4
228
+
229
+    final = run_phases(
230
+        store,
231
+        mined_doc,
232
+        spec,
233
+        plan,
234
+        phase="preference",
235
+        capabilities=capabilities,
236
+        lock_mode="ignore",
237
+        seed=42,
238
+        max_steps=20,
239
+    )
240
+    assert [result.phase for result in final] == ["preference"]
241
+    assert final[0].result.adapter_version == 3
242
+
243
+    manifest = load_manifest(store.manifest)
244
+    assert manifest.adapter_version == 3
245
+    assert len(manifest.training_runs) >= 3
246
+
247
+    baseline_margin = _mean_margin_for_version(doc, store, 2)
248
+    final_margin = _mean_margin_for_version(doc, store, 3)
249
+    assert final_margin > baseline_margin, (
250
+        "expected final preference-tuned adapter to improve held-out sway margin "
251
+        f"(baseline={baseline_margin:.4f}, final={final_margin:.4f})"
252
+    )
tests/unit/metrics/test_recorder.py (modified)
@@ -19,7 +19,7 @@ from dlm.metrics.events import (
1919
     RunStart,
2020
     StepEvent,
2121
 )
22
-from dlm.metrics.recorder import MetricsRecorder
22
+from dlm.metrics.recorder import DlmTrainerCallback, MetricsRecorder
2323
 
2424
 
2525
 def _select_all(db_path: Path, table: str) -> list[tuple]:
@@ -220,3 +220,19 @@ class TestAnchorWrites:
220220
 
221221
         with pytest.raises(sqlite3.OperationalError, match="database is locked"):
222222
             rec.record_run_end(RunEnd(run_id=1, status="ok"))
223
+
224
+
225
+class TestTrainerCallbackCompatibility:
226
+    def test_unknown_hf_lifecycle_hooks_fall_back_to_noop(self, tmp_path: Path) -> None:
227
+        callback = DlmTrainerCallback(MetricsRecorder(tmp_path), run_id=7)
228
+
229
+        assert callable(callback.on_train_begin)
230
+        assert callable(callback.on_train_end)
231
+        assert callback.on_train_begin(None, None, None) is None
232
+        assert callback.on_train_end(None, None, None) is None
233
+
234
+    def test_non_callback_missing_attr_still_raises(self, tmp_path: Path) -> None:
235
+        callback = DlmTrainerCallback(MetricsRecorder(tmp_path), run_id=7)
236
+
237
+        with pytest.raises(AttributeError, match="not_a_callback"):
238
+            _ = callback.not_a_callback