Python · 3761 bytes Raw Blame History
1 """Metrics recorder for the learned adapter gate."""
2
3 from __future__ import annotations
4
5 import sqlite3
6 from pathlib import Path
7
8 import pytest
9
10 from dlm.metrics.db import connect
11 from dlm.metrics.events import GateEvent, RunStart
12 from dlm.metrics.recorder import MetricsRecorder
13
14
@pytest.fixture
def recorder(tmp_path: Path) -> MetricsRecorder:
    """Return a ``MetricsRecorder`` stored under *tmp_path* with run 1 already started.

    ``gate_events`` rows reference ``runs(run_id)``, so every test that
    records a gate event needs the parent run row to exist beforehand.
    """
    metrics = MetricsRecorder(store_root=tmp_path)
    first_run = RunStart(run_id=1, adapter_version=1, phase="sft", seed=42)
    metrics.record_run_start(first_run)
    return metrics
21
22
class TestGateEventSchema:
    """``GateEvent`` construction, focused on its ``at`` timestamp field."""

    def test_populated_at(self) -> None:
        """Leaving ``at`` unset still yields a non-empty, ISO-8601-looking string."""
        stamp = GateEvent(
            run_id=1,
            adapter_name="lexer",
            mean_weight=0.7,
            sample_count=32,
            mode="trained",
        ).at
        assert stamp != ""
        # The date/time separator is a cheap sanity check for iso-8601 shape.
        assert "T" in stamp

    def test_explicit_at_preserved(self) -> None:
        """A caller-supplied ``at`` value is kept verbatim."""
        explicit_stamp = "2026-04-21T12:00:00+00:00"
        event = GateEvent(
            run_id=1,
            adapter_name="runtime",
            mean_weight=0.3,
            sample_count=16,
            mode="trained",
            at=explicit_stamp,
        )
        assert event.at == explicit_stamp
41
42
class TestRecordGate:
    """``MetricsRecorder.record_gate`` persistence into the ``gate_events`` table."""

    def test_insert_and_read_back(self, recorder: MetricsRecorder, tmp_path: Path) -> None:
        """Two adapters recorded for run 1 are read back unchanged, one row each."""
        for name, weight, count in (("lexer", 0.7, 32), ("runtime", 0.3, 16)):
            recorder.record_gate(
                GateEvent(
                    run_id=1,
                    adapter_name=name,
                    mean_weight=weight,
                    sample_count=count,
                    mode="trained",
                )
            )
        query = (
            "SELECT adapter_name, mean_weight, sample_count, mode "
            "FROM gate_events WHERE run_id = 1 ORDER BY adapter_name"
        )
        with connect(tmp_path) as conn:
            rows = conn.execute(query).fetchall()
        assert rows == [("lexer", 0.7, 32, "trained"), ("runtime", 0.3, 16, "trained")]

    def test_replace_on_duplicate(self, recorder: MetricsRecorder, tmp_path: Path) -> None:
        """Recording the same (run_id, adapter_name) pair twice keeps one row.

        That pair is the primary key, so the second write overwrites the first
        rather than adding a duplicate.
        """
        for weight, count in ((0.5, 10), (0.9, 20)):
            recorder.record_gate(
                GateEvent(
                    run_id=1,
                    adapter_name="lexer",
                    mean_weight=weight,
                    sample_count=count,
                    mode="trained",
                )
            )
        with connect(tmp_path) as conn:
            rows = conn.execute(
                "SELECT mean_weight, sample_count FROM gate_events WHERE run_id = 1"
            ).fetchall()
        # Only the later write's values should survive.
        assert rows == [(0.9, 20)]

    def test_uniform_mode_recorded(self, recorder: MetricsRecorder, tmp_path: Path) -> None:
        """An event with ``mode="uniform"`` is stored with that exact mode string."""
        event = GateEvent(
            run_id=1, adapter_name="a", mean_weight=0.5, sample_count=2, mode="uniform"
        )
        recorder.record_gate(event)
        with connect(tmp_path) as conn:
            cursor = conn.execute(
                "SELECT mode FROM gate_events WHERE run_id = 1 AND adapter_name='a'"
            )
            (stored_mode,) = next(cursor)
        assert stored_mode == "uniform"
100
101
def test_schema_includes_gate_events_table(tmp_path: Path) -> None:
    """Opening a connection alone is enough for ``gate_events`` to exist.

    The migration path creates the table unconditionally on connect, so no
    recorder or run row is needed here.
    """
    with connect(tmp_path) as conn:
        assert isinstance(conn, sqlite3.Connection)
        table_names = {
            name
            for (name,) in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        }
    assert "gate_events" in table_names