Python · 11345 bytes Raw Blame History
1 """Post-SFT gate orchestration — probe extraction + run_post_sft_gate."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6 from types import MappingProxyType
7 from typing import Any
8
9 import pytest
10
11 import dlm.train.gate.orchestrator as gate_orchestrator
12 from dlm.doc.parser import ParsedDlm
13 from dlm.doc.schema import AdapterConfig, DlmFrontmatter, GateConfig, TrainingConfig
14 from dlm.doc.sections import Section, SectionType
15 from dlm.metrics.events import RunStart
16 from dlm.metrics.recorder import MetricsRecorder
17 from dlm.store.paths import StorePath
18 from dlm.train.gate.errors import GateTrainingError
19 from dlm.train.gate.orchestrator import (
20 GateProbe,
21 probes_from_sections,
22 run_post_sft_gate,
23 )
24
25
def _frontmatter(
    *,
    gate_enabled: bool = True,
    adapters: tuple[str, ...] = ("a", "b"),
) -> DlmFrontmatter:
    """Build a minimal ``dlm_version=8`` frontmatter for gate tests.

    Every name in *adapters* gets an ``AdapterConfig(lora_r=4)``; an empty
    tuple produces no adapter map at all (``None``). The gate config is
    toggled by *gate_enabled* and uses ``cold_start_floor=4`` and only
    ``steps=20`` so unit tests stay short.
    """
    adapter_map: dict[str, AdapterConfig] | None = None
    if adapters:
        adapter_map = {name: AdapterConfig(lora_r=4) for name in adapters}

    gate = GateConfig(
        enabled=gate_enabled,
        cold_start_floor=4,
        steps=20,  # short for unit tests
    )
    training = TrainingConfig(adapters=adapter_map, gate=gate)
    return DlmFrontmatter(
        dlm_id="01HRSHWZ" + "0" * 18,
        dlm_version=8,
        base_model="smollm2-135m",
        training=training,
    )
45
46
def _instruction(content: str, *, adapter: str | None) -> Section:
    """Return an INSTRUCTION ``Section`` holding *content*, routed to *adapter*.

    The section starts at line 0 and carries an empty read-only tag map.
    """
    no_tags = MappingProxyType({})
    return Section(
        content=content,
        adapter=adapter,
        type=SectionType.INSTRUCTION,
        tags=no_tags,
        start_line=0,
    )
55
56
def _prose(content: str, *, adapter: str | None) -> Section:
    """Return a PROSE ``Section`` holding *content*, routed to *adapter*.

    The section starts at line 0 and carries an empty read-only tag map.
    """
    no_tags = MappingProxyType({})
    return Section(
        content=content,
        adapter=adapter,
        type=SectionType.PROSE,
        tags=no_tags,
        start_line=0,
    )
65
66
def _preference(content: str, *, adapter: str | None) -> Section:
    """Return a PREFERENCE ``Section`` holding *content*, routed to *adapter*.

    The section starts at line 0 and carries an empty read-only tag map.
    """
    no_tags = MappingProxyType({})
    return Section(
        content=content,
        adapter=adapter,
        type=SectionType.PREFERENCE,
        tags=no_tags,
        start_line=0,
    )
75
76
def _parsed(sections: tuple[Section, ...], **fm_kwargs: Any) -> ParsedDlm:
    """Wrap *sections* in a ``ParsedDlm`` that has no source path.

    Keyword arguments (e.g. ``gate_enabled``, ``adapters``) are forwarded
    unchanged to ``_frontmatter``. They are typed ``Any`` because they are a
    heterogeneous pass-through; this avoids the ``arg-type`` suppression
    that an ``object`` annotation required.
    """
    return ParsedDlm(
        frontmatter=_frontmatter(**fm_kwargs),
        sections=sections,
        source_path=None,
    )
83
84
class TestProbesFromSections:
    """Probe extraction from adapter-tagged document sections."""

    def test_drops_untagged_sections(self) -> None:
        # The adapter-less section is dropped; the tagged one becomes a probe.
        sections = (
            _prose("hello", adapter=None),
            _prose("world", adapter="a"),
        )
        probes = probes_from_sections(_parsed(sections))
        assert probes == [GateProbe(adapter_name="a", prompt="world")]

    def test_extracts_instruction_question(self) -> None:
        # Only the question text becomes the prompt, not the answer.
        body = "### Q\nWhat is lexing?\n### A\nTurning source into tokens.\n"
        probes = probes_from_sections(_parsed((_instruction(body, adapter="a"),)))
        assert probes == [GateProbe(adapter_name="a", prompt="What is lexing?")]

    def test_multiple_qa_uses_first_pair(self) -> None:
        body = (
            "### Q\nFirst question?\n### A\nFirst answer.\n\n"
            "### Q\nSecond question?\n### A\nSecond answer.\n"
        )
        probes = probes_from_sections(_parsed((_instruction(body, adapter="b"),)))
        assert probes[0].adapter_name == "b"
        assert probes[0].prompt == "First question?"

    def test_extracts_preference_prompt(self) -> None:
        # The prompt, not the chosen/rejected completions, becomes the probe.
        body = "### Prompt\nWhich answer is better?\n### Chosen\nA\n### Rejected\nB\n"
        probes = probes_from_sections(_parsed((_preference(body, adapter="b"),)))
        assert probes == [GateProbe(adapter_name="b", prompt="Which answer is better?")]

    def test_unparseable_instruction_is_skipped(self) -> None:
        # A body with no Q/A headers yields no probe rather than raising.
        probes = probes_from_sections(_parsed((_instruction("no Q/A pairs here", adapter="a"),)))
        assert probes == []

    def test_prose_truncates_to_cap(self) -> None:
        # Over-long prose is cut down to the 2048-character probe cap.
        long = "x" * 5000
        probes = probes_from_sections(_parsed((_prose(long, adapter="a"),)))
        assert len(probes) == 1
        assert len(probes[0].prompt) == 2048
121
122
class TestRunPostSftGate:
    """End-to-end checks of ``run_post_sft_gate`` against a temporary store."""

    @staticmethod
    def _store_and_recorder(
        tmp_path: Path, *, start_run: bool = True
    ) -> tuple[StorePath, MetricsRecorder]:
        """Lay out a fresh store under *tmp_path* and open a recorder on it.

        When *start_run* is true (the default), a ``RunStart`` for run 1
        (phase ``"sft"``, seed 42) is recorded before returning.
        """
        store = StorePath(root=tmp_path)
        store.ensure_layout()
        recorder = MetricsRecorder(tmp_path)
        if start_run:
            recorder.record_run_start(RunStart(run_id=1, adapter_version=1, phase="sft", seed=42))
        return store, recorder

    def test_disabled_gate_returns_none(self, tmp_path: Path) -> None:
        parsed = _parsed(
            (_prose("x", adapter="a"),),
            gate_enabled=False,
        )
        store, recorder = self._store_and_recorder(tmp_path)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=lambda _p: _tensor(4),
            input_dim=4,
        )
        assert result is None

    def test_single_adapter_returns_none(self, tmp_path: Path) -> None:
        # The schema refuses an enabled gate on a single-named-adapter doc
        # (and presumably on one with no adapter map). So build the
        # frontmatter with no adapter map and the gate off, then force the
        # gate on past validation. This simulates a gate that's "enabled"
        # but has nothing to route between. Without the override, this
        # test would only re-cover the disabled-gate path.
        parsed = ParsedDlm(
            frontmatter=DlmFrontmatter(
                dlm_id="01HRSHWZ" + "0" * 18,
                dlm_version=8,
                base_model="smollm2-135m",
                training=TrainingConfig(
                    adapters=None,
                    gate=GateConfig(enabled=False),
                ),
            ),
            sections=(_prose("x", adapter="a"),),
            source_path=None,
        )
        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
        store, recorder = self._store_and_recorder(tmp_path, start_run=False)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=lambda _p: _tensor(4),
            input_dim=4,
        )
        assert result is None

    def test_exactly_one_named_adapter_returns_none(self, tmp_path: Path) -> None:
        # Same bypass as above: validate with the gate off, then force it on.
        parsed = _parsed((_prose("x", adapter="solo"),), gate_enabled=False, adapters=("solo",))
        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
        store, recorder = self._store_and_recorder(tmp_path, start_run=False)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=lambda _p: _tensor(4),
            input_dim=4,
        )
        assert result is None

    def test_cold_start_fallback_records_uniform_events(self, tmp_path: Path) -> None:
        # A single supervising sample is below cold_start_floor=4, so the
        # gate falls back to uniform routing.
        parsed = _parsed((_prose("only-a", adapter="a"),))
        store, recorder = self._store_and_recorder(tmp_path)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=lambda _p: _tensor(4),
            input_dim=4,
        )
        assert result is not None
        assert result.mode == "uniform"

        # Gate config file was written (uniform mode).
        from dlm.train.gate.paths import gate_config_path

        assert gate_config_path(store).exists()

        # Events for every declared adapter, each with mean_weight = 1/N.
        from dlm.metrics.db import connect

        with connect(tmp_path) as conn:
            rows = list(
                conn.execute(
                    "SELECT adapter_name, mean_weight, sample_count, mode "
                    "FROM gate_events WHERE run_id = 1 ORDER BY adapter_name"
                )
            )
        assert len(rows) == 2
        assert {name for name, _w, _c, _m in rows} == {"a", "b"}
        for _name, weight, _count, mode in rows:
            assert mode == "uniform"
            assert weight == pytest.approx(0.5)

    def test_trained_mode_records_calibrated_mean_weight(self, tmp_path: Path) -> None:
        # Enough supervising samples for both adapters. Use two clear
        # clusters so training actually separates them.
        import torch

        sections: list[Section] = []
        for i in range(6):
            sections.append(_prose(f"alpha-{i}", adapter="a"))
            sections.append(_prose(f"beta-{i}", adapter="b"))
        parsed = _parsed(tuple(sections))

        # A locally seeded generator keeps the embedding noise (and so this
        # test) deterministic without touching torch's global RNG state.
        noise = torch.Generator().manual_seed(0)

        def embed(prompt: str) -> torch.Tensor:
            # Cluster 'alpha' at +1, 'beta' at -1.
            sign = 1.0 if prompt.startswith("alpha") else -1.0
            return sign * torch.ones(4) + 0.05 * torch.randn(4, generator=noise)

        store, recorder = self._store_and_recorder(tmp_path)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=embed,
            input_dim=4,
        )
        assert result is not None
        assert result.mode == "trained"
        # With a balanced supervision set, the calibrated mean weights
        # should be roughly the prior split. Only the invariant that always
        # holds is asserted: the per-adapter mean weights sum to 1.0.
        names = tuple(result.per_adapter_mean_weight.keys())
        assert set(names) == {"a", "b"}
        total = sum(result.per_adapter_mean_weight.values())
        assert total == pytest.approx(1.0, abs=1e-4)

        # gate_events rows reflect the calibrated weights.
        from dlm.metrics.db import connect

        with connect(tmp_path) as conn:
            rows = dict(
                conn.execute(
                    "SELECT adapter_name, mean_weight FROM gate_events WHERE run_id = 1"
                ).fetchall()
            )
        assert rows["a"] == pytest.approx(result.per_adapter_mean_weight["a"])
        assert rows["b"] == pytest.approx(result.per_adapter_mean_weight["b"])

    def test_gate_training_error_returns_none(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        parsed = _parsed(
            (
                _prose("alpha", adapter="a"),
                _prose("beta", adapter="b"),
            )
        )
        store, recorder = self._store_and_recorder(tmp_path)

        def _raise_gate_error(*args: object, **kwargs: object) -> None:
            raise GateTrainingError("boom")

        # Patch the name as the orchestrator module sees it, so its call
        # to train_gate raises.
        monkeypatch.setattr(gate_orchestrator, "train_gate", _raise_gate_error)
        result = run_post_sft_gate(
            store,
            parsed,
            run_id=1,
            recorder=recorder,
            embed=lambda _p: _tensor(4),
            input_dim=4,
        )
        assert result is None

        # Divergence emits one `mode="diverged"` GateEvent per declared
        # adapter so `dlm show` surfaces the failure instead of silently
        # skipping the gate.
        from dlm.metrics import queries as _queries

        events = _queries.gate_events_for_run(tmp_path, 1)
        assert {e.adapter_name for e in events} == {"a", "b"}
        assert all(e.mode == "diverged" for e in events)
        assert all(e.mean_weight == 0.0 and e.sample_count == 0 for e in events)
315
316
317 def _tensor(d: int) -> Any:
318 import torch
319
320 return torch.zeros(d)