1 """Reader-side queries against a seeded metrics DB."""
2
3 from __future__ import annotations
4
5 import sqlite3
6 from datetime import UTC, datetime, timedelta
7 from pathlib import Path
8
9 import pytest
10
11 from dlm.metrics.events import (
12 EvalEvent,
13 GateEvent,
14 PreferenceMineEvent,
15 RunEnd,
16 RunStart,
17 StepEvent,
18 TokenizationEvent,
19 )
20 from dlm.metrics.queries import (
21 evals_for_run,
22 evals_to_dict,
23 gate_events_for_run,
24 latest_gate_events,
25 latest_preference_mining,
26 latest_run_id,
27 latest_tokenization,
28 preference_mining_for_run,
29 preference_mining_to_dict,
30 preference_mining_totals,
31 recent_runs,
32 runs_to_dict,
33 steps_for_run,
34 steps_to_dict,
35 tokenization_for_run,
36 )
37 from dlm.metrics.recorder import MetricsRecorder
38
39
def _seed(store_root: Path) -> None:
    """Populate a DB with three runs and a handful of steps/evals."""
    recorder = MetricsRecorder(store_root)

    # Three complete runs, each with three steps, one eval, and a clean end.
    for rid in (1, 2, 3):
        recorder.record_run_start(
            RunStart(run_id=rid, adapter_version=rid, phase="sft", seed=42)
        )
        for step in (10, 20, 30):
            recorder.record_step(StepEvent(run_id=rid, step=step, loss=2.0 - 0.1 * step))
        recorder.record_eval(EvalEvent(run_id=rid, step=30, val_loss=1.5))
        recorder.record_run_end(RunEnd(run_id=rid, status="ok"))

    # A single tokenization event attached to the newest run.
    recorder.record_tokenization(
        TokenizationEvent(
            run_id=3,
            total_sections=10,
            cache_hits=7,
            cache_misses=3,
            total_tokenize_seconds=0.75,
            cache_bytes_after=4096,
        )
    )

    # Two gate events for run 2, one per adapter (insertion order matters to tests).
    for adapter, weight in (("tone", 0.8), ("facts", 0.2)):
        recorder.record_gate(
            GateEvent(
                run_id=2,
                adapter_name=adapter,
                mean_weight=weight,
                sample_count=12,
                mode="trained",
            )
        )

    # Two preference-mining events for run 2 with distinct judges and write modes.
    mining_specs = (
        ("sway", 4, 1, 0, "staged"),
        ("hf:test/reward", 6, 2, 3, "applied"),
    )
    for judge, samples, mined, skipped, mode in mining_specs:
        recorder.record_preference_mine(
            PreferenceMineEvent(
                run_id=2,
                judge_name=judge,
                sample_count=samples,
                mined_pairs=mined,
                skipped_prompts=skipped,
                write_mode=mode,
            )
        )
97
98
class TestRecentRuns:
    """Queries over the `runs` table via `recent_runs` (ordering, limit, filters)."""

    def test_returns_runs_newest_first(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        runs = recent_runs(tmp_path, limit=10)
        assert [r.run_id for r in runs] == [3, 2, 1]

    def test_limit_caps_results(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        runs = recent_runs(tmp_path, limit=2)
        assert len(runs) == 2

    def test_phase_filter(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        runs = recent_runs(tmp_path, phase="sft")
        assert len(runs) == 3
        runs_dpo = recent_runs(tmp_path, phase="dpo")
        assert runs_dpo == []

    def test_run_id_filter(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        runs = recent_runs(tmp_path, run_id=2)
        assert len(runs) == 1
        assert runs[0].run_id == 2

    def test_since_filter_excludes_old_runs(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        # Hack: rewrite one started_at to be far in the past.
        # sqlite3 is already imported at module level; the previous redundant
        # local import was removed. Close the connection even if a statement
        # raises, so the file handle never leaks into later tests.
        conn = sqlite3.connect(str(tmp_path / "metrics.sqlite"))
        try:
            old_ts = (datetime.now(UTC) - timedelta(days=30)).isoformat().replace("+00:00", "Z")
            conn.execute("UPDATE runs SET started_at = ? WHERE run_id = 1", (old_ts,))
            conn.commit()
        finally:
            conn.close()

        # 24h window → run 1 should drop out.
        runs = recent_runs(tmp_path, since=timedelta(hours=24))
        assert [r.run_id for r in runs] == [3, 2]
137
138
class TestStepsAndEvals:
    """Step and eval row retrieval: ordering and step-based filtering."""

    def test_steps_ordered_by_step(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        recorded = steps_for_run(tmp_path, run_id=1)
        assert [row.step for row in recorded] == [10, 20, 30]

    def test_steps_since_filter(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        # Only steps strictly after the cutoff should remain.
        recorded = steps_for_run(tmp_path, run_id=1, since_step=15)
        assert [row.step for row in recorded] == [20, 30]

    def test_evals_for_run(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        rows = evals_for_run(tmp_path, run_id=2)
        assert len(rows) == 1
        assert rows[0].val_loss == 1.5
155
156
class TestLatestRunId:
    """`latest_run_id` against populated, empty, and unreachable databases."""

    def test_returns_max(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        assert latest_run_id(tmp_path) == 3

    def test_none_when_empty(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        # Opening a connection creates the schema without inserting any rows.
        with connect(tmp_path):
            pass
        assert latest_run_id(tmp_path) is None

    def test_none_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        # A failing connect() must degrade to None, not propagate.
        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert latest_run_id(tmp_path) is None
182
183
class TestTokenizationQueries:
    """Per-run and latest tokenization rows, plus hit-rate arithmetic."""

    def test_tokenization_for_run_returns_row_with_hit_rate(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        result = tokenization_for_run(tmp_path, run_id=3)
        assert result is not None
        assert result.cache_hits == 7
        # 7 hits out of 10 lookups.
        assert result.hit_rate == 0.7

    def test_tokenization_for_run_none_when_table_has_no_row(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        # Schema exists, table is empty.
        with connect(tmp_path):
            pass
        assert tokenization_for_run(tmp_path, run_id=3) is None

    def test_hit_rate_zero_when_total_lookups_is_zero(self) -> None:
        from dlm.metrics.queries import TokenizationRow

        empty = TokenizationRow(
            run_id=1,
            total_sections=0,
            cache_hits=0,
            cache_misses=0,
            total_tokenize_seconds=0.0,
            cache_bytes_after=0,
            at="2026-01-01T00:00:00Z",
        )
        # Zero lookups must not trigger a division by zero.
        assert empty.hit_rate == 0.0

    def test_tokenization_for_run_none_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert tokenization_for_run(tmp_path, run_id=1) is None

    def test_latest_tokenization_returns_most_recent_row(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        result = latest_tokenization(tmp_path)
        assert result is not None
        assert result.run_id == 3

    def test_latest_tokenization_none_when_empty(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        with connect(tmp_path):
            pass
        assert latest_tokenization(tmp_path) is None

    def test_latest_tokenization_none_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert latest_tokenization(tmp_path) is None
251
252
class TestGateQueries:
    """Gate-event retrieval per run and for the newest run with gates."""

    def test_gate_events_for_run_returns_rows_sorted_by_adapter(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        names = [event.adapter_name for event in gate_events_for_run(tmp_path, run_id=2)]
        # Sorted alphabetically by adapter, not by insertion order.
        assert names == ["facts", "tone"]

    def test_gate_events_for_run_returns_empty_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert gate_events_for_run(tmp_path, run_id=2) == []

    def test_latest_gate_events_returns_latest_run_rows(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        names = [event.adapter_name for event in latest_gate_events(tmp_path)]
        assert names == ["facts", "tone"]

    def test_latest_gate_events_empty_when_table_empty(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        # Schema exists but holds no gate rows.
        with connect(tmp_path):
            pass
        assert latest_gate_events(tmp_path) == []

    def test_latest_gate_events_empty_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert latest_gate_events(tmp_path) == []
296
297
class TestPreferenceMiningQueries:
    """Preference-mining rows: per-run listing, latest event, and totals."""

    def test_preference_mining_for_run_returns_oldest_first(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        events = preference_mining_for_run(tmp_path, run_id=2)
        assert [event.judge_name for event in events] == ["sway", "hf:test/reward"]
        assert [event.write_mode for event in events] == ["staged", "applied"]

    def test_latest_preference_mining_returns_most_recent_event(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        latest = latest_preference_mining(tmp_path)
        assert latest is not None
        assert latest.judge_name == "hf:test/reward"
        assert latest.write_mode == "applied"

    def test_latest_preference_mining_none_when_empty(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        # Schema exists, no mining events recorded.
        with connect(tmp_path):
            pass
        assert latest_preference_mining(tmp_path) is None

    def test_preference_mining_totals_aggregate_across_events(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        totals = preference_mining_totals(tmp_path)
        assert totals is not None
        # Both seeded events belong to run 2 → one distinct run, two events.
        assert totals.run_count == 1
        assert totals.event_count == 2
        assert totals.total_mined_pairs == 3
        assert totals.total_skipped_prompts == 3

    def test_preference_mining_for_run_returns_empty_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert preference_mining_for_run(tmp_path, run_id=2) == []

    def test_latest_preference_mining_returns_none_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert latest_preference_mining(tmp_path) is None

    def test_preference_mining_totals_none_when_table_empty(self, tmp_path: Path) -> None:
        from dlm.metrics.db import connect

        with connect(tmp_path):
            pass
        assert preference_mining_totals(tmp_path) is None

    def test_preference_mining_totals_none_on_sqlite_error(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        import dlm.metrics.queries as queries_mod

        def _raise(_store_root: Path) -> sqlite3.Connection:
            raise sqlite3.OperationalError("boom")

        monkeypatch.setattr(queries_mod, "connect", _raise)
        assert preference_mining_totals(tmp_path) is None
373
374
class TestDictSerialization:
    """JSON-ready dict conversion helpers for runs, steps, evals, and mining."""

    def test_runs_to_dict_shape(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        payload = runs_to_dict(recent_runs(tmp_path, limit=1))
        expected_keys = {
            "run_id",
            "started_at",
            "ended_at",
            "adapter_version",
            "phase",
            "seed",
            "status",
        }
        assert payload[0].keys() == expected_keys

    def test_steps_and_evals_to_dict(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        step_dicts = steps_to_dict(steps_for_run(tmp_path, run_id=1))
        required = {"step", "loss", "lr", "grad_norm", "at"}
        assert all(required.issubset(d.keys()) for d in step_dicts)
        eval_dicts = evals_to_dict(evals_for_run(tmp_path, run_id=1))
        assert all("val_loss" in d for d in eval_dicts)

    def test_preference_mining_to_dict_shape(self, tmp_path: Path) -> None:
        _seed(tmp_path)
        payload = preference_mining_to_dict(preference_mining_for_run(tmp_path, run_id=2))
        expected_keys = {
            "event_id",
            "run_id",
            "judge_name",
            "sample_count",
            "mined_pairs",
            "skipped_prompts",
            "write_mode",
            "at",
        }
        assert payload[0].keys() == expected_keys
409 }