tenseleyflow/documentlanguagemodel / 5f8f3a8

Browse files

Filter preference rows from replay in SFT build_dataset

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
5f8f3a87fdd98859ecc71ff4ece371f04726f93d
Parents
b8ca343
Tree
b07ab60

2 changed files

StatusFile+-
M src/dlm/data/dataset_builder.py 9 1
M tests/unit/data/test_dataset_builder.py 17 0
src/dlm/data/dataset_builder.pymodified
@@ -64,7 +64,7 @@ def build_dataset(
6464
         audio_token=audio_token,
6565
     )
6666
     if replay_rows is not None:
67
-        rows.extend(replay_rows)
67
+        rows.extend(r for r in replay_rows if not _is_preference_row(r))
6868
 
6969
     if not rows:
7070
         raise ValueError(
@@ -79,3 +79,11 @@ def build_dataset(
7979
             )
8080
 
8181
     return split(rows, val_frac=val_frac, seed=seed)
82
+
83
+
84
+def _is_preference_row(row: Row) -> bool:
85
+    return (
86
+        row.get("prompt") is not None
87
+        and row.get("chosen") is not None
88
+        and row.get("rejected") is not None
89
+    )
tests/unit/data/test_dataset_builder.pymodified
@@ -51,6 +51,23 @@ class TestBuildDataset:
5151
         all_text = {r["text"] for r in list(train) + list(val)}
5252
         assert {"source doc prose", "replay-1", "replay-2"}.issubset(all_text)
5353
 
54
+    def test_preference_replay_rows_filtered(self) -> None:
55
+        sections = [_s(SectionType.PROSE, "source doc prose")]
56
+        replay = [
57
+            {"text": "sft-replay", "_dlm_section_id": "replay-sft"},
58
+            {
59
+                "prompt": "q",
60
+                "chosen": "good",
61
+                "rejected": "bad",
62
+                "_dlm_section_id": "replay-pref",
63
+            },
64
+        ]
65
+        train, val = build_dataset(sections, seed=0, val_frac=0.1, replay_rows=replay)
66
+        all_rows = list(train) + list(val)
67
+        all_text = {r.get("text") for r in all_rows if r.get("text")}
68
+        assert "sft-replay" in all_text
69
+        assert not any(r.get("prompt") == "q" for r in all_rows)
70
+
5471
     def test_empty_rows_raises(self) -> None:
5572
         sections = [_s(SectionType.PROSE, "   ")]
5673
         with pytest.raises(ValueError, match="no trainable rows"):