Python · 10899 bytes Raw Blame History
1 """Reader-side queries against the metrics DB.
2
Used by `dlm metrics` and `dlm metrics watch`. Every query function takes a
`store_root: Path` and opens its own read-only connection; the `*_to_dict`
helpers only convert already-fetched rows into JSON-ready dicts. SQLite in
WAL mode allows concurrent readers without blocking the trainer's writer.
7 """
8
9 from __future__ import annotations
10
11 import sqlite3
12 from dataclasses import dataclass
13 from datetime import UTC, datetime, timedelta
14 from pathlib import Path
15 from typing import Any
16
17 from dlm.metrics.db import connect
18
19
@dataclass(frozen=True)
class RunRow:
    """Immutable snapshot of a single record in the `runs` table.

    Fields are declared in the same order as the columns selected by
    `recent_runs`, so a fetched tuple can be splatted into the constructor.
    """

    # Identity and timing.
    run_id: int
    started_at: str
    ended_at: str | None
    # Run descriptors; each may be NULL in the table.
    adapter_version: int | None
    phase: str | None
    seed: int | None
    status: str | None
31
32
@dataclass(frozen=True)
class StepRow:
    """Immutable snapshot of a single record in the `steps` table.

    Field order mirrors the column list selected by `steps_for_run`.
    """

    run_id: int
    step: int
    # Per-step training measurements; each column may be NULL.
    loss: float | None
    lr: float | None
    grad_norm: float | None
    tokens_per_sec: float | None
    peak_vram_mb: int | None
    at: str
43
44
@dataclass(frozen=True)
class EvalRow:
    """Immutable snapshot of a single record in the `evals` table.

    Field order mirrors the column list selected by `evals_for_run`.
    """

    run_id: int
    step: int
    # Evaluation measurements; each column may be NULL.
    val_loss: float | None
    perplexity: float | None
    retention: float | None
    at: str
53
54
@dataclass(frozen=True)
class TokenizationRow:
    """Immutable snapshot of a single record in the `tokenization` table."""

    run_id: int
    total_sections: int
    cache_hits: int
    cache_misses: int
    total_tokenize_seconds: float
    cache_bytes_after: int
    at: str

    @property
    def hit_rate(self) -> float:
        """Share of cache lookups that were hits, or 0.0 if there were none."""
        lookups = self.cache_hits + self.cache_misses
        if not lookups:
            # Avoid dividing by zero when the cache was never consulted.
            return 0.0
        return self.cache_hits / lookups
71
72
@dataclass(frozen=True)
class PreferenceMineRow:
    """Immutable snapshot of a single record in the `preference_mining` table.

    Field order mirrors the column list selected by the preference-mining
    query helpers in this module.
    """

    # Identity of the event and the run it belongs to.
    event_id: int
    run_id: int
    # Mining parameters and outcome counts.
    judge_name: str
    sample_count: int
    mined_pairs: int
    skipped_prompts: int
    write_mode: str
    at: str
85
86
@dataclass(frozen=True)
class PreferenceMineTotals:
    """Table-wide aggregates over every `preference_mining` event."""

    run_count: int  # distinct runs that recorded at least one event
    event_count: int  # total number of events
    total_mined_pairs: int
    total_skipped_prompts: int
95
96
97 def recent_runs(
98 store_root: Path,
99 *,
100 limit: int = 20,
101 phase: str | None = None,
102 since: timedelta | None = None,
103 run_id: int | None = None,
104 ) -> list[RunRow]:
105 """Return the most-recent runs matching the filters.
106
107 Filters compose: `phase="sft"` AND `since=timedelta(hours=24)` AND
108 `run_id=4` are all applied. `limit` caps the result set.
109 """
110 sql = "SELECT run_id, started_at, ended_at, adapter_version, phase, seed, status FROM runs"
111 clauses: list[str] = []
112 params: list[Any] = []
113 if phase is not None:
114 clauses.append("phase = ?")
115 params.append(phase)
116 if run_id is not None:
117 clauses.append("run_id = ?")
118 params.append(run_id)
119 if since is not None:
120 cutoff = (datetime.now(UTC) - since).isoformat().replace("+00:00", "Z")
121 clauses.append("started_at >= ?")
122 params.append(cutoff)
123 if clauses:
124 sql += " WHERE " + " AND ".join(clauses)
125 sql += " ORDER BY run_id DESC LIMIT ?"
126 params.append(limit)
127
128 with connect(store_root) as conn:
129 rows = conn.execute(sql, params).fetchall()
130 return [RunRow(*row) for row in rows]
131
132
def steps_for_run(store_root: Path, run_id: int, *, since_step: int = 0) -> list[StepRow]:
    """Fetch the training-step rows recorded for `run_id`, in step order.

    Only rows with `step` strictly greater than `since_step` are returned.
    This lets `dlm metrics watch` poll for rows it has not yet seen.
    """
    query = (
        "SELECT run_id, step, loss, lr, grad_norm, tokens_per_sec, peak_vram_mb, at "
        "FROM steps WHERE run_id = ? AND step > ? ORDER BY step ASC"
    )
    with connect(store_root) as conn:
        cursor = conn.execute(query, (run_id, since_step))
        records = cursor.fetchall()
    return [StepRow(*record) for record in records]
146
147
def evals_for_run(store_root: Path, run_id: int, *, since_step: int = 0) -> list[EvalRow]:
    """Fetch the evaluation rows recorded for `run_id`, in step order.

    As with `steps_for_run`, `since_step` is an exclusive lower bound on
    `step`, so incremental pollers receive only rows they have not seen.
    """
    query = (
        "SELECT run_id, step, val_loss, perplexity, retention, at "
        "FROM evals WHERE run_id = ? AND step > ? ORDER BY step ASC"
    )
    with connect(store_root) as conn:
        cursor = conn.execute(query, (run_id, since_step))
        records = cursor.fetchall()
    return [EvalRow(*record) for record in records]
157
158
def tokenization_for_run(store_root: Path, run_id: int) -> TokenizationRow | None:
    """Fetch the tokenization record for `run_id`.

    Returns None when the table holds no row for this run (for example, the
    run predated tokenization metrics or did not touch the directive cache).
    Also returns None when the read raises any `sqlite3.Error`, such as a
    missing table.
    """
    query = (
        "SELECT run_id, total_sections, cache_hits, cache_misses, "
        "total_tokenize_seconds, cache_bytes_after, at "
        "FROM tokenization WHERE run_id = ?"
    )
    try:
        with connect(store_root) as conn:
            found = conn.execute(query, (run_id,)).fetchone()
    except sqlite3.Error:
        return None
    return None if found is None else TokenizationRow(*found)
178
179
def latest_tokenization(store_root: Path) -> TokenizationRow | None:
    """Fetch the tokenization record with the highest `run_id` (used by `dlm show`).

    Returns None when the table is empty or the read raises any
    `sqlite3.Error` (e.g. the DB or table is missing).
    """
    query = (
        "SELECT run_id, total_sections, cache_hits, cache_misses, "
        "total_tokenize_seconds, cache_bytes_after, at "
        "FROM tokenization ORDER BY run_id DESC LIMIT 1"
    )
    try:
        with connect(store_root) as conn:
            newest = conn.execute(query).fetchone()
    except sqlite3.Error:
        return None
    return None if newest is None else TokenizationRow(*newest)
195
196
def preference_mining_for_run(store_root: Path, run_id: int) -> list[PreferenceMineRow]:
    """Fetch every preference-mine event for `run_id`, ordered by `event_id` ascending.

    Any `sqlite3.Error` during the read yields an empty list.
    """
    query = (
        "SELECT event_id, run_id, judge_name, sample_count, mined_pairs, "
        "skipped_prompts, write_mode, at "
        "FROM preference_mining WHERE run_id = ? ORDER BY event_id ASC"
    )
    try:
        with connect(store_root) as conn:
            events = conn.execute(query, (run_id,)).fetchall()
    except sqlite3.Error:
        return []
    return [PreferenceMineRow(*event) for event in events]
210
211
def latest_preference_mining(store_root: Path) -> PreferenceMineRow | None:
    """Fetch the preference-mine event with the highest `event_id`.

    Returns None when the table is empty or the read raises any
    `sqlite3.Error`.
    """
    query = (
        "SELECT event_id, run_id, judge_name, sample_count, mined_pairs, "
        "skipped_prompts, write_mode, at "
        "FROM preference_mining ORDER BY event_id DESC LIMIT 1"
    )
    try:
        with connect(store_root) as conn:
            newest = conn.execute(query).fetchone()
    except sqlite3.Error:
        return None
    return None if newest is None else PreferenceMineRow(*newest)
226
227
def preference_mining_totals(store_root: Path) -> PreferenceMineTotals | None:
    """Summarise every preference-mine event in a single aggregate row.

    Returns None when the table has no events, or when the read raises any
    `sqlite3.Error` (e.g. the table does not exist).
    """
    query = (
        "SELECT COUNT(DISTINCT run_id), COUNT(*), "
        "COALESCE(SUM(mined_pairs), 0), COALESCE(SUM(skipped_prompts), 0) "
        "FROM preference_mining"
    )
    try:
        with connect(store_root) as conn:
            aggregate = conn.execute(query).fetchone()
    except sqlite3.Error:
        return None
    # COUNT(*) of zero means there is nothing to summarise.
    if aggregate is None or int(aggregate[1]) == 0:
        return None
    runs, events, mined, skipped = (int(value) for value in aggregate)
    return PreferenceMineTotals(
        run_count=runs,
        event_count=events,
        total_mined_pairs=mined,
        total_skipped_prompts=skipped,
    )
250
251
@dataclass(frozen=True)
class GateEventRow:
    """Immutable snapshot of one `gate_events` record (one per run and adapter).

    Field order mirrors the column list selected by `gate_events_for_run`.
    """

    # Which run and adapter this record describes.
    run_id: int
    adapter_name: str
    # Recorded gate statistics.
    mean_weight: float
    sample_count: int
    mode: str
    at: str
262
263
def gate_events_for_run(store_root: Path, run_id: int) -> list[GateEventRow]:
    """Fetch every gate_events record for `run_id`, sorted by adapter name.

    The list is empty when the run recorded no gate (e.g. because
    `training.gate.enabled` was false) or when the read raises any
    `sqlite3.Error`.
    """
    query = (
        "SELECT run_id, adapter_name, mean_weight, sample_count, mode, at "
        "FROM gate_events WHERE run_id = ? ORDER BY adapter_name"
    )
    try:
        with connect(store_root) as conn:
            events = conn.execute(query, (run_id,)).fetchall()
    except sqlite3.Error:
        return []
    return [GateEventRow(*event) for event in events]
280
281
def latest_gate_events(store_root: Path) -> list[GateEventRow]:
    """Return all gate_events rows for the most recent run that recorded a gate.

    The newest run is resolved with a `MAX(run_id)` sub-select inside the
    same statement that fetches its rows. Both steps therefore run on one
    connection and read one consistent snapshot, so a run that lands in
    between cannot split the lookup from the fetch. Rows are ordered by
    adapter name, matching `gate_events_for_run`.

    Returns an empty list when no run has gate-event rows yet (the
    sub-select yields NULL, which matches nothing) or when the read raises
    any `sqlite3.Error`.
    """
    try:
        with connect(store_root) as conn:
            rows = conn.execute(
                "SELECT run_id, adapter_name, mean_weight, sample_count, mode, at "
                "FROM gate_events "
                "WHERE run_id = (SELECT MAX(run_id) FROM gate_events) "
                "ORDER BY adapter_name"
            ).fetchall()
    except sqlite3.Error:
        return []
    return [GateEventRow(*row) for row in rows]
293
294
def latest_run_id(store_root: Path) -> int | None:
    """Return the highest `run_id` recorded in `runs`.

    None when the table has no rows or the read raises any `sqlite3.Error`
    (e.g. the DB is missing).
    """
    try:
        with connect(store_root) as conn:
            result = conn.execute("SELECT MAX(run_id) FROM runs").fetchone()
    except sqlite3.Error:
        return None
    # MAX() over an empty table yields a single NULL.
    newest = None if result is None else result[0]
    return None if newest is None else int(newest)
305
306
def runs_to_dict(runs: list[RunRow]) -> list[dict[str, Any]]:
    """Convert run rows into JSON-serializable dicts for `dlm metrics --json`.

    Each dict carries every `RunRow` field, in declaration order.
    """
    keys = ("run_id", "started_at", "ended_at", "adapter_version", "phase", "seed", "status")
    return [{key: getattr(run, key) for key in keys} for run in runs]
321
322
def steps_to_dict(steps: list[StepRow]) -> list[dict[str, Any]]:
    """Convert step rows into JSON-serializable dicts.

    `run_id` is intentionally left out; every other `StepRow` field is
    included, in declaration order.
    """
    keys = ("step", "loss", "lr", "grad_norm", "tokens_per_sec", "peak_vram_mb", "at")
    return [{key: getattr(step, key) for key in keys} for step in steps]
336
337
def evals_to_dict(evals: list[EvalRow]) -> list[dict[str, Any]]:
    """Convert eval rows into JSON-serializable dicts.

    `run_id` is intentionally left out; every other `EvalRow` field is
    included, in declaration order.
    """
    keys = ("step", "val_loss", "perplexity", "retention", "at")
    return [{key: getattr(evaluation, key) for key in keys} for evaluation in evals]
349
350
def preference_mining_to_dict(rows: list[PreferenceMineRow]) -> list[dict[str, Any]]:
    """Convert preference-mine rows into JSON-serializable dicts.

    Used by `dlm metrics --json` and `dlm show --json`. Each dict carries
    every `PreferenceMineRow` field, in declaration order.
    """
    keys = (
        "event_id",
        "run_id",
        "judge_name",
        "sample_count",
        "mined_pairs",
        "skipped_prompts",
        "write_mode",
        "at",
    )
    return [{key: getattr(event, key) for key in keys} for event in rows]
365 ]