tenseleyflow/documentlanguagemodel / 35c558f

Browse files

Record preference mine metrics

Authored by espadonne
SHA
35c558faed456e06672a25accc894e3f9df52537
Parents
dcb09c5
Tree
6f96a60

10 changed files

Status | File | + | -
M src/dlm/cli/commands.py 19 0
M src/dlm/metrics/__init__.py 2 0
M src/dlm/metrics/db.py 12 0
M src/dlm/metrics/events.py 18 0
M src/dlm/metrics/queries.py 62 0
M src/dlm/metrics/recorder.py 20 0
M tests/unit/cli/test_preference_cmd.py 9 0
M tests/unit/metrics/test_db_schema.py 10 1
M tests/unit/metrics/test_queries.py 60 1
M tests/unit/metrics/test_recorder.py 33 1
src/dlm/cli/commands.py (modified)
@@ -3734,6 +3734,8 @@ def preference_mine_cmd(
37343734
         build_backend,
37353735
         select_backend,
37363736
     )
3737
+    from dlm.metrics import MetricsRecorder, PreferenceMineEvent
3738
+    from dlm.metrics.events import PreferenceMineWriteMode
37373739
     from dlm.modality import modality_for
37383740
     from dlm.preference import (
37393741
         InvalidJudgeSpecError,
@@ -3858,10 +3860,25 @@ def preference_mine_cmd(
38583860
     finally:
38593861
         backend_obj.unload()
38603862
 
3863
+    recorder = MetricsRecorder(store.root)
3864
+
3865
+    def _record_preference_mine(write_mode: PreferenceMineWriteMode) -> None:
3866
+        recorder.record_preference_mine(
3867
+            PreferenceMineEvent(
3868
+                run_id=run_id,
3869
+                judge_name=judge_obj.name,
3870
+                sample_count=samples,
3871
+                mined_pairs=len(plan.additions),
3872
+                skipped_prompts=len(plan.skipped),
3873
+                write_mode=write_mode,
3874
+            )
3875
+        )
3876
+
38613877
     out_console.print(render_mine_plan(plan))
38623878
 
38633879
     if not plan.additions:
38643880
         clear_pending_plan(store)
3881
+        _record_preference_mine("empty")
38653882
         out_console.print(
38663883
             "\n[yellow]no candidates to mine[/yellow] — either instruction prompts "
38673884
             "did not yield a confident pair, or the matching preference sections "
@@ -3877,6 +3894,7 @@ def preference_mine_cmd(
38773894
         out_console.print(render_apply_plan(apply_plan))
38783895
         summary = apply_preference_plan(parsed, apply_plan, target=path)
38793896
         clear_pending_plan(store)
3897
+        _record_preference_mine("applied")
38803898
         out_console.print(
38813899
             f"\n[green]preference:[/green] wrote {summary.added} section(s) to {path} "
38823900
             f"({summary.skipped} skipped)"
@@ -3884,6 +3902,7 @@ def preference_mine_cmd(
38843902
         return
38853903
 
38863904
     pending = save_pending_plan(store, source_path=path.resolve(), sections=sections)
3905
+    _record_preference_mine("staged")
38873906
     out_console.print(
38883907
         f"\n[green]preference:[/green] staged {len(pending.sections)} mined preference "
38893908
         f"section(s). Run [bold]dlm preference apply {path}[/bold] to write them."
src/dlm/metrics/__init__.py (modified)
@@ -12,6 +12,7 @@ from dlm.metrics.events import (
1212
     EvalEvent,
1313
     ExportEvent,
1414
     Phase,
15
+    PreferenceMineEvent,
1516
     RunEnd,
1617
     RunStart,
1718
     Status,
@@ -28,6 +29,7 @@ __all__ = [
2829
     "MetricsRecorder",
2930
     "MetricsSchemaError",
3031
     "Phase",
32
+    "PreferenceMineEvent",
3133
     "RunEnd",
3234
     "RunStart",
3335
     "Status",
src/dlm/metrics/db.py (modified)
@@ -96,6 +96,18 @@ _SCHEMA_SQL = [
9696
         PRIMARY KEY (run_id, adapter_name)
9797
     )
9898
     """,
99
+    """
100
+    CREATE TABLE IF NOT EXISTS preference_mining (
101
+        event_id         INTEGER PRIMARY KEY AUTOINCREMENT,
102
+        run_id           INTEGER NOT NULL,
103
+        judge_name       TEXT NOT NULL,
104
+        sample_count     INTEGER NOT NULL,
105
+        mined_pairs      INTEGER NOT NULL,
106
+        skipped_prompts  INTEGER NOT NULL,
107
+        write_mode       TEXT NOT NULL,
108
+        at               TEXT NOT NULL
109
+    )
110
+    """,
99111
 ]
100112
 
101113
 
src/dlm/metrics/events.py (modified)
@@ -12,6 +12,7 @@ from typing import Literal
1212
 
1313
 Phase = Literal["sft", "dpo", "orpo", "cpt"]
1414
 Status = Literal["running", "ok", "failed", "cancelled"]
15
+PreferenceMineWriteMode = Literal["staged", "applied", "empty"]
1516
 
1617
 
1718
 def _utc_iso() -> str:
@@ -130,6 +131,23 @@ class GateEvent:
130131
             object.__setattr__(self, "at", _utc_iso())
131132
 
132133
 
134
+@dataclass(frozen=True)
135
+class PreferenceMineEvent:
136
+    """Emitted from `dlm preference mine` after judging completes."""
137
+
138
+    run_id: int
139
+    judge_name: str
140
+    sample_count: int
141
+    mined_pairs: int
142
+    skipped_prompts: int
143
+    write_mode: PreferenceMineWriteMode
144
+    at: str = ""
145
+
146
+    def __post_init__(self) -> None:
147
+        if not self.at:
148
+            object.__setattr__(self, "at", _utc_iso())
149
+
150
+
133151
 @dataclass(frozen=True)
134152
 class ExportEvent:
135153
     """Emitted from `dlm export` on completion."""
src/dlm/metrics/queries.py (modified)
@@ -70,6 +70,20 @@ class TokenizationRow:
7070
         return self.cache_hits / total if total else 0.0
7171
 
7272
 
73
+@dataclass(frozen=True)
74
+class PreferenceMineRow:
75
+    """One row from the `preference_mining` table."""
76
+
77
+    event_id: int
78
+    run_id: int
79
+    judge_name: str
80
+    sample_count: int
81
+    mined_pairs: int
82
+    skipped_prompts: int
83
+    write_mode: str
84
+    at: str
85
+
86
+
7387
 def recent_runs(
7488
     store_root: Path,
7589
     *,
@@ -170,6 +184,37 @@ def latest_tokenization(store_root: Path) -> TokenizationRow | None:
170184
     return TokenizationRow(*row)
171185
 
172186
 
187
+def preference_mining_for_run(store_root: Path, run_id: int) -> list[PreferenceMineRow]:
188
+    """All preference-mine events for `run_id`, oldest first."""
189
+    try:
190
+        with connect(store_root) as conn:
191
+            rows = conn.execute(
192
+                "SELECT event_id, run_id, judge_name, sample_count, mined_pairs, "
193
+                "skipped_prompts, write_mode, at "
194
+                "FROM preference_mining WHERE run_id = ? ORDER BY event_id ASC",
195
+                (run_id,),
196
+            ).fetchall()
197
+    except sqlite3.Error:
198
+        return []
199
+    return [PreferenceMineRow(*row) for row in rows]
200
+
201
+
202
+def latest_preference_mining(store_root: Path) -> PreferenceMineRow | None:
203
+    """The most-recent preference-mine row, or None when absent."""
204
+    try:
205
+        with connect(store_root) as conn:
206
+            row = conn.execute(
207
+                "SELECT event_id, run_id, judge_name, sample_count, mined_pairs, "
208
+                "skipped_prompts, write_mode, at "
209
+                "FROM preference_mining ORDER BY event_id DESC LIMIT 1"
210
+            ).fetchone()
211
+    except sqlite3.Error:
212
+        return None
213
+    if row is None:
214
+        return None
215
+    return PreferenceMineRow(*row)
216
+
217
+
173218
 @dataclass(frozen=True)
174219
 class GateEventRow:
175220
     """One row of the gate_events table (per-run per-adapter)."""
@@ -267,3 +312,20 @@ def evals_to_dict(evals: list[EvalRow]) -> list[dict[str, Any]]:
267312
         }
268313
         for e in evals
269314
     ]
315
+
316
+
317
+def preference_mining_to_dict(rows: list[PreferenceMineRow]) -> list[dict[str, Any]]:
318
+    """JSON-serializable view used by `dlm metrics --json` and `dlm show --json`."""
319
+    return [
320
+        {
321
+            "event_id": row.event_id,
322
+            "run_id": row.run_id,
323
+            "judge_name": row.judge_name,
324
+            "sample_count": row.sample_count,
325
+            "mined_pairs": row.mined_pairs,
326
+            "skipped_prompts": row.skipped_prompts,
327
+            "write_mode": row.write_mode,
328
+            "at": row.at,
329
+        }
330
+        for row in rows
331
+    ]
src/dlm/metrics/recorder.py (modified)
@@ -38,6 +38,7 @@ from dlm.metrics.events import (
3838
     EvalEvent,
3939
     ExportEvent,
4040
     GateEvent,
41
+    PreferenceMineEvent,
4142
     RunEnd,
4243
     RunStart,
4344
     StepEvent,
@@ -195,6 +196,25 @@ class MetricsRecorder:
195196
 
196197
         self._with_conn(_do, failure_key="gate", hard_fail=False)
197198
 
199
+    def record_preference_mine(self, event: PreferenceMineEvent) -> None:
200
+        def _do(conn: sqlite3.Connection) -> None:
201
+            conn.execute(
202
+                "INSERT INTO preference_mining "
203
+                "(run_id, judge_name, sample_count, mined_pairs, skipped_prompts, write_mode, at) "
204
+                "VALUES (?, ?, ?, ?, ?, ?, ?)",
205
+                (
206
+                    event.run_id,
207
+                    event.judge_name,
208
+                    event.sample_count,
209
+                    event.mined_pairs,
210
+                    event.skipped_prompts,
211
+                    event.write_mode,
212
+                    event.at,
213
+                ),
214
+            )
215
+
216
+        self._with_conn(_do, failure_key="preference_mine", hard_fail=False)
217
+
198218
     def record_export(self, event: ExportEvent) -> None:
199219
         def _do(conn: sqlite3.Connection) -> None:
200220
             conn.execute(
tests/unit/cli/test_preference_cmd.py (modified)
@@ -13,6 +13,7 @@ from typer.testing import CliRunner
1313
 from dlm.base_models import BaseModelSpec
1414
 from dlm.cli.app import app
1515
 from dlm.doc.parser import parse_file
16
+from dlm.metrics.queries import preference_mining_for_run
1617
 from dlm.preference.judge import PairScore
1718
 from dlm.preference.pending import load_pending_plan
1819
 from dlm.store.manifest import Manifest, TrainingRunSummary, save_manifest
@@ -163,6 +164,14 @@ class TestPreferenceCmd:
163164
         assert len(pending.sections) == 1
164165
         assert pending.sections[0].auto_mined is True
165166
 
167
+        rows = preference_mining_for_run(for_dlm(_DLM_ID, home=home).root, run_id=7)
168
+        assert len(rows) == 1
169
+        assert rows[0].judge_name == "stub:judge"
170
+        assert rows[0].sample_count == 2
171
+        assert rows[0].mined_pairs == 1
172
+        assert rows[0].skipped_prompts == 0
173
+        assert rows[0].write_mode == "staged"
174
+
166175
     def test_apply_writes_staged_preferences_and_clears_pending(
167176
         self,
168177
         tmp_path: Path,
tests/unit/metrics/test_db_schema.py (modified)
@@ -22,7 +22,16 @@ class TestConnect:
2222
             tables = {
2323
                 row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
2424
             }
25
-        assert tables == {"runs", "steps", "evals", "exports", "tokenization", "gate_events"}
25
+        user_tables = {table for table in tables if not table.startswith("sqlite_")}
26
+        assert user_tables == {
27
+            "runs",
28
+            "steps",
29
+            "evals",
30
+            "exports",
31
+            "tokenization",
32
+            "gate_events",
33
+            "preference_mining",
34
+        }
2635
 
2736
     def test_wal_mode_enabled(self, tmp_path: Path) -> None:
2837
         with connect(tmp_path) as conn:
tests/unit/metrics/test_queries.py (modified)
@@ -5,11 +5,14 @@ from __future__ import annotations
55
 from datetime import UTC, datetime, timedelta
66
 from pathlib import Path
77
 
8
-from dlm.metrics.events import EvalEvent, RunEnd, RunStart, StepEvent
8
+from dlm.metrics.events import EvalEvent, PreferenceMineEvent, RunEnd, RunStart, StepEvent
99
 from dlm.metrics.queries import (
1010
     evals_for_run,
1111
     evals_to_dict,
12
+    latest_preference_mining,
1213
     latest_run_id,
14
+    preference_mining_for_run,
15
+    preference_mining_to_dict,
1316
     recent_runs,
1417
     runs_to_dict,
1518
     steps_for_run,
@@ -27,6 +30,26 @@ def _seed(store_root: Path) -> None:
2730
             rec.record_step(StepEvent(run_id=run_id, step=step, loss=2.0 - 0.1 * step))
2831
         rec.record_eval(EvalEvent(run_id=run_id, step=30, val_loss=1.5))
2932
         rec.record_run_end(RunEnd(run_id=run_id, status="ok"))
33
+    rec.record_preference_mine(
34
+        PreferenceMineEvent(
35
+            run_id=2,
36
+            judge_name="sway",
37
+            sample_count=4,
38
+            mined_pairs=1,
39
+            skipped_prompts=0,
40
+            write_mode="staged",
41
+        )
42
+    )
43
+    rec.record_preference_mine(
44
+        PreferenceMineEvent(
45
+            run_id=2,
46
+            judge_name="hf:test/reward",
47
+            sample_count=6,
48
+            mined_pairs=2,
49
+            skipped_prompts=3,
50
+            write_mode="applied",
51
+        )
52
+    )
3053
 
3154
 
3255
 class TestRecentRuns:
@@ -101,6 +124,28 @@ class TestLatestRunId:
101124
         assert latest_run_id(tmp_path) is None
102125
 
103126
 
127
+class TestPreferenceMiningQueries:
128
+    def test_preference_mining_for_run_returns_oldest_first(self, tmp_path: Path) -> None:
129
+        _seed(tmp_path)
130
+        rows = preference_mining_for_run(tmp_path, run_id=2)
131
+        assert [row.judge_name for row in rows] == ["sway", "hf:test/reward"]
132
+        assert [row.write_mode for row in rows] == ["staged", "applied"]
133
+
134
+    def test_latest_preference_mining_returns_most_recent_event(self, tmp_path: Path) -> None:
135
+        _seed(tmp_path)
136
+        row = latest_preference_mining(tmp_path)
137
+        assert row is not None
138
+        assert row.judge_name == "hf:test/reward"
139
+        assert row.write_mode == "applied"
140
+
141
+    def test_latest_preference_mining_none_when_empty(self, tmp_path: Path) -> None:
142
+        from dlm.metrics.db import connect
143
+
144
+        with connect(tmp_path) as _conn:
145
+            pass
146
+        assert latest_preference_mining(tmp_path) is None
147
+
148
+
104149
 class TestDictSerialization:
105150
     def test_runs_to_dict_shape(self, tmp_path: Path) -> None:
106151
         _seed(tmp_path)
@@ -122,3 +167,17 @@ class TestDictSerialization:
122167
         assert all({"step", "loss", "lr", "grad_norm", "at"}.issubset(s.keys()) for s in steps)
123168
         evals = evals_to_dict(evals_for_run(tmp_path, run_id=1))
124169
         assert all("val_loss" in e for e in evals)
170
+
171
+    def test_preference_mining_to_dict_shape(self, tmp_path: Path) -> None:
172
+        _seed(tmp_path)
173
+        payload = preference_mining_to_dict(preference_mining_for_run(tmp_path, run_id=2))
174
+        assert payload[0].keys() == {
175
+            "event_id",
176
+            "run_id",
177
+            "judge_name",
178
+            "sample_count",
179
+            "mined_pairs",
180
+            "skipped_prompts",
181
+            "write_mode",
182
+            "at",
183
+        }
tests/unit/metrics/test_recorder.py (modified)
@@ -11,7 +11,14 @@ from pathlib import Path
1111
 import pytest
1212
 
1313
 from dlm.metrics.db import metrics_db_path
14
-from dlm.metrics.events import EvalEvent, ExportEvent, RunEnd, RunStart, StepEvent
14
+from dlm.metrics.events import (
15
+    EvalEvent,
16
+    ExportEvent,
17
+    PreferenceMineEvent,
18
+    RunEnd,
19
+    RunStart,
20
+    StepEvent,
21
+)
1522
 from dlm.metrics.recorder import MetricsRecorder
1623
 
1724
 
@@ -121,6 +128,31 @@ class TestExports:
121128
         assert rows[0][3] == "mydoc:v1"
122129
 
123130
 
131
+class TestPreferenceMining:
132
+    def test_preference_mine_written_without_run_row(self, tmp_path: Path) -> None:
133
+        rec = MetricsRecorder(tmp_path)
134
+        rec.record_preference_mine(
135
+            PreferenceMineEvent(
136
+                run_id=7,
137
+                judge_name="sway",
138
+                sample_count=4,
139
+                mined_pairs=2,
140
+                skipped_prompts=1,
141
+                write_mode="staged",
142
+            )
143
+        )
144
+        rows = _select_all(metrics_db_path(tmp_path), "preference_mining")
145
+        assert len(rows) == 1
146
+        _, run_id, judge_name, sample_count, mined_pairs, skipped_prompts, write_mode, at = rows[0]
147
+        assert run_id == 7
148
+        assert judge_name == "sway"
149
+        assert sample_count == 4
150
+        assert mined_pairs == 2
151
+        assert skipped_prompts == 1
152
+        assert write_mode == "staged"
153
+        assert at
154
+
155
+
124156
 class TestBestEffort:
125157
     def test_step_write_logs_error_once_per_stream(
126158
         self,