Python · 8708 bytes Raw Blame History
1 """Turn `doc.sections.Section` objects into ready-to-train dict rows.
2
3 Current shape table:
4
5 | Section type | Row shape |
6 |---|---|
7 | PROSE | `{"text": <raw prose>}` |
8 | INSTRUCTION | one `{"messages": [{"role":"user","content":Q},{"role":"assistant","content":A}]}` per Q/A pair |
9 | PREFERENCE | one `{"prompt":P,"chosen":C,"rejected":R}` per triple |
10 | IMAGE | `{"images": [PIL.Image], "text": "<image>\\n<caption>"}` — matches TRL 1.2's `DataCollatorForVisionLanguageModeling` standard-LM contract |
11 | AUDIO | `{"audio_blob_sha": sha, "audio_path": str, "text": "<|AUDIO|>\\n<transcript>"}` — path-based (TRL has no audio auto-dispatch; a custom collator resolves the blob and drives `preprocess_audio` at collate time) |
12
13 IMAGE / AUDIO emission requires a `BlobStore` (to resolve
14 `media_blob_sha` into bytes) and the base's placeholder token.
15 Callers that leave `blob_store=None` with media sections in the
16 input raise `ValueError` — the row shape isn't viable without the
17 actual bytes. Audio rows hold only the path + sha, not the decoded
18 waveform; the audio cache is the right place to hold preprocessed
19 features across epochs, and loading lazily at collate time keeps
20 dataset rows small.
21
22 Every row carries `_dlm_section_id` so `splitter.split()` can key
23 deterministically on (seed, section_id) rather than row index. This is
24 what makes adding sections to a document *not* reshuffle the existing
25 train/val assignments.
26 """
27
28 from __future__ import annotations
29
30 from typing import TYPE_CHECKING, Any
31
32 from dlm.data.instruction_parser import parse_instruction_body
33 from dlm.data.preference_parser import parse_preference_body
34 from dlm.doc.sections import Section, SectionType
35
36 if TYPE_CHECKING:
37 from dlm.store.blobs import BlobStore
38
39 _PROBE_MARKER = "!probe"
40 _PROBE_HEADER = f"### Q {_PROBE_MARKER}"
41
42
43 def _normalize_probe_markers(body: str) -> str:
44 """Rewrite `### Q !probe` → `### Q` so the strict parser accepts it.
45
46 Mirrors `dlm.eval.probes._normalize_probe_markers` (kept local to
47 avoid a data → eval import). Probe-marked Q/A pairs still train
48 exactly like plain pairs; the marker is only load-bearing for probe
49 extraction. We drop it silently here rather than leak `!probe:` into
50 the training question text.
51 """
52 if _PROBE_HEADER not in body:
53 return body
54 lines = body.splitlines()
55 rewritten = [("### Q" if line.strip() == _PROBE_HEADER else line) for line in lines]
56 return "\n".join(rewritten)
57
58
# A single ready-to-train example as a plain dict; the concrete keys depend
# on the section type (see the shape table in the module docstring).
Row = dict[str, Any]

# Default media placeholder tokens inserted into row text. Callers pass the
# base model's own placeholder via `sections_to_rows(image_token=..., audio_token=...)`
# when it differs from these defaults.
_DEFAULT_IMAGE_TOKEN = "<image>"
_DEFAULT_AUDIO_TOKEN = "<|AUDIO|>"
63
64
def sections_to_rows(
    sections: list[Section],
    *,
    blob_store: BlobStore | None = None,
    image_token: str = _DEFAULT_IMAGE_TOKEN,
    audio_token: str = _DEFAULT_AUDIO_TOKEN,
) -> list[Row]:
    """Expand each section into its row shape(s), keeping document order.

    Blank PROSE sections are silently skipped — empty document regions
    must not become empty training rows. Empty INSTRUCTION / PREFERENCE
    bodies are treated as parse errors by their respective parsers.

    Media sections (IMAGE / AUDIO) need `blob_store` to resolve
    `media_blob_sha` into bytes, and `image_token` / `audio_token` as the
    textual placeholder that the base model's processor later expands to
    its fixed token window at collate time. A `None` `blob_store` combined
    with media sections raises `ValueError`.
    """
    return [
        row
        for section in sections
        for row in _section_to_rows(
            section,
            blob_store=blob_store,
            image_token=image_token,
            audio_token=audio_token,
        )
    ]
97
98
def _section_to_rows(
    section: Section,
    *,
    blob_store: BlobStore | None,
    image_token: str,
    audio_token: str,
) -> list[Row]:
    """Dispatch one section to its row list (possibly empty for blank prose)."""
    section_id = section.section_id
    # One shared tags dict per section; every emitted row references it.
    row_tags = dict(section.tags)
    stype = section.type

    if stype is SectionType.PROSE:
        stripped = section.content.strip()
        if not stripped:
            # Blank prose region: emit nothing rather than an empty row.
            return []
        return [{"text": stripped, "_dlm_section_id": section_id, "_dlm_row_tags": row_tags}]

    if stype is SectionType.INSTRUCTION:
        normalized = _normalize_probe_markers(section.content)
        chat_rows: list[Row] = []
        for pair in parse_instruction_body(normalized, section_id=section_id):
            chat_rows.append(
                {
                    "messages": [
                        {"role": "user", "content": pair.question},
                        {"role": "assistant", "content": pair.answer},
                    ],
                    "_dlm_section_id": section_id,
                    "_dlm_row_tags": row_tags,
                }
            )
        return chat_rows

    if stype is SectionType.PREFERENCE:
        pref_rows: list[Row] = []
        for triple in parse_preference_body(section.content, section_id=section_id):
            pref_rows.append(
                {
                    "prompt": triple.prompt,
                    "chosen": triple.chosen,
                    "rejected": triple.rejected,
                    "_dlm_section_id": section_id,
                    "_dlm_row_tags": row_tags,
                }
            )
        return pref_rows

    if stype is SectionType.IMAGE:
        return [_image_section_to_row(section, blob_store, image_token, section_id, row_tags)]

    if stype is SectionType.AUDIO:
        return [_audio_section_to_row(section, blob_store, audio_token, section_id, row_tags)]

    raise ValueError(f"unknown section type: {section.type!r}")  # pragma: no cover
149
150
151 def _image_section_to_row(
152 section: Section,
153 blob_store: BlobStore | None,
154 image_token: str,
155 sid: str,
156 tags: dict[str, str],
157 ) -> Row:
158 """Emit a TRL 1.2-shaped VL row: `{images: [PIL], text: "<image>\\n..."}`.
159
160 The caller (dataset_builder / trainer) owns the BlobStore; sections
161 with `media_blob_sha is None` reach this path only via tests that
162 construct Sections manually without ingest — we refuse them
163 explicitly rather than silently dropping them.
164 """
165 if blob_store is None:
166 raise ValueError(
167 "sections_to_rows: IMAGE section requires a blob_store; "
168 f"section {sid} has media_path={section.media_path!r}",
169 )
170 if section.media_blob_sha is None:
171 raise ValueError(
172 "sections_to_rows: IMAGE section has no media_blob_sha "
173 f"(section {sid} hasn't been ingested through the blob store)",
174 )
175
176 from PIL import Image
177
178 blob_path = blob_store.get(section.media_blob_sha)
179 with Image.open(blob_path) as pil:
180 pil.load()
181 image = pil.convert("RGB")
182
183 caption = section.content.strip()
184 text = f"{image_token}\n{caption}" if caption else image_token
185 return {
186 "images": [image],
187 "text": text,
188 "_dlm_section_id": sid,
189 "_dlm_row_tags": tags,
190 }
191
192
193 def _audio_section_to_row(
194 section: Section,
195 blob_store: BlobStore | None,
196 audio_token: str,
197 sid: str,
198 tags: dict[str, str],
199 ) -> Row:
200 """Emit an audio row: path + sha + transcript-prefixed text.
201
202 Audio rows carry the blob path + sha rather than a decoded
203 waveform so the custom collator (T8) can drive `preprocess_audio`
204 with the content-addressed cache — decoding + feature-extraction
205 is expensive and repeats across epochs. The transcript from the
206 sibling `<stem>.txt` goes into the row's `text` after the
207 placeholder; the custom collator replaces the placeholder with
208 the feature-extractor's `num_audio_tokens` slots before the
209 trainer sees it.
210
211 Audio sections with an empty transcript aren't useful for SFT
212 (no target text to predict) — refuse loudly rather than emit a
213 placeholder-only row that would train the model to produce an
214 empty response.
215 """
216 if blob_store is None:
217 raise ValueError(
218 "sections_to_rows: AUDIO section requires a blob_store; "
219 f"section {sid} has media_path={section.media_path!r}",
220 )
221 if section.media_blob_sha is None:
222 raise ValueError(
223 "sections_to_rows: AUDIO section has no media_blob_sha "
224 f"(section {sid} hasn't been ingested through the blob store)",
225 )
226 transcript = (section.media_transcript or "").strip()
227 if not transcript:
228 raise ValueError(
229 "sections_to_rows: AUDIO section has empty transcript "
230 f"(section {sid}, media_path={section.media_path!r}); "
231 "transcript sibling `<stem>.txt` must be non-empty",
232 )
233
234 blob_path = blob_store.get(section.media_blob_sha)
235 text = f"{audio_token}\n{transcript}"
236 return {
237 "audio_blob_sha": section.media_blob_sha,
238 "audio_path": str(blob_path),
239 "text": text,
240 "_dlm_section_id": sid,
241 "_dlm_row_tags": tags,
242 }