Python · 3039 bytes Raw Blame History
1 """End-to-end: parsed `.dlm` sections → (train_ds, val_ds).
2
3 This is the single entry point the trainer calls. It:
4
5 1. Flattens `sections` to dict rows via `sections_to_rows`.
6 2. Optionally concatenates a replay-corpus row iterable (we just
7 accept an iterable here to keep the dependency one-directional).
8 3. Splits into train / val via the deterministic splitter.
9
10 The split is keyed on each row's `_dlm_section_id` + sub-index, so
11 replay rows must also carry a stable `_dlm_section_id` — the corpus
12 reader stamps one derived from the originating document's version.
13 """
14
15 from __future__ import annotations
16
17 from collections.abc import Iterable, Mapping
18 from typing import TYPE_CHECKING, Any
19
20 from dlm.data.sections_to_rows import sections_to_rows
21 from dlm.data.splitter import split
22 from dlm.data.weighted_rows import expand_rows_by_weight
23 from dlm.doc.sections import Section
24
25 if TYPE_CHECKING:
26 from datasets import Dataset
27
28 from dlm.store.blobs import BlobStore
29
30 Row = dict[str, Any]
31
32
def build_dataset(
    sections: list[Section],
    *,
    val_frac: float = 0.1,
    seed: int,
    replay_rows: Iterable[Row] | None = None,
    weights: Mapping[str, Mapping[str, float]] | None = None,
    blob_store: BlobStore | None = None,
    image_token: str = "<image>",
    audio_token: str = "<|AUDIO|>",
) -> tuple[Dataset, Dataset]:
    """Turn parsed `.dlm` sections into a `(train, val)` `Dataset` pair.

    The `seed` has no default on purpose: every split must be traceable
    back to a manifest entry. `val_frac=0.1` is the current default.

    When `weights` is non-empty, rows are expanded by
    `(tag_key, tag_value)` multipliers *before* splitting. Integer
    factors duplicate rows; fractional factors make a deterministic
    per-section keep/drop decision. Both in-document and replay rows go
    through the same expansion so retention is uniform.

    `blob_store`, `image_token`, and `audio_token` are forwarded to
    `sections_to_rows` so media sections can be emitted. Vision- or
    audio-language bases need the store; text-only documents can keep
    the defaults.

    Raises `ValueError` when no trainable rows exist, or when weight
    expansion drops every row.
    """
    # Flatten the document itself first; replay rows are appended after.
    all_rows: list[Row] = sections_to_rows(
        sections,
        blob_store=blob_store,
        image_token=image_token,
        audio_token=audio_token,
    )
    if replay_rows is not None:
        # Preference-style replay rows are excluded from SFT-style training.
        non_preference = (r for r in replay_rows if not _is_preference_row(r))
        all_rows.extend(non_preference)

    if not all_rows:
        raise ValueError(
            "no trainable rows — document has no non-empty PROSE/INSTRUCTION/PREFERENCE sections"
        )

    if weights:
        all_rows = expand_rows_by_weight(all_rows, weights, seed=seed)
        # Fractional weights can legitimately drop rows; dropping *all* of
        # them is almost certainly a misconfiguration worth failing loudly on.
        if not all_rows:
            raise ValueError(
                "weights dropped every row — check `training.yaml` weights for zeros across all tag values"
            )

    return split(all_rows, val_frac=val_frac, seed=seed)
82
83
84 def _is_preference_row(row: Row) -> bool:
85 return (
86 row.get("prompt") is not None
87 and row.get("chosen") is not None
88 and row.get("rejected") is not None
89 )