1 """Deterministic train / val split keyed on (seed, section_id).
2
3 The invariant: adding a section to the `.dlm` does NOT reshuffle the
4 existing assignments. Every row's train-vs-val fate is a pure function
5 of `(seed, row["_dlm_section_id"], sub_index)` — the sub-index is the
6 row's position within its section (so a single INSTRUCTION block with
7 ten Q/A pairs distributes those pairs across the split independently).
8
9 The split is computed by hashing `(seed, section_id, sub_index)` and
10 comparing against `val_frac * 2**64`. This is stable across Python
11 versions (we use `hashlib.sha256` rather than `hash()`).
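
For illustration (numbers made up): with `seed=7` and `val_frac=0.1`,
the fate of the third row of section "intro" is whatever

    _assigns_to_val(seed=7, section_id="intro", sub_index=2,
                    threshold=int(0.1 * 2**64))

returns; nothing about any other row or section enters the computation.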
12 """
13
14 from __future__ import annotations
15
16 import hashlib
17 from collections import defaultdict
18 from typing import TYPE_CHECKING, Any
19
20 if TYPE_CHECKING:
21 from datasets import Dataset
22
23 Row = dict[str, Any]
24
25
26 def split(
    rows: list[Row],
    *,
    val_frac: float,
    seed: int,
) -> tuple[Dataset, Dataset]:
    """Partition `rows` into (train_ds, val_ds) datasets.

    `val_frac` must be in (0, 1). `seed` is combined with each row's
    `_dlm_section_id` + its in-section sub-index to produce a stable
    assignment.

    Raises `ValueError` if `val_frac` is outside (0, 1) or if any row
    lacks `_dlm_section_id`.
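
    A minimal usage sketch (assumes the `datasets` package is installed;
    the section ids and fields are made up):

        rows = [
            {"_dlm_section_id": "intro", "text": "q1"},
            {"_dlm_section_id": "intro", "text": "q2"},
            {"_dlm_section_id": "faq", "text": "q3"},
        ]
        train_ds, val_ds = split(rows, val_frac=0.25, seed=0)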
40 """
    from datasets import Dataset

    if not 0.0 < val_frac < 1.0:
        raise ValueError(f"val_frac must be in (0, 1), got {val_frac!r}")

    train_rows: list[Row] = []
    val_rows: list[Row] = []
    per_section_index: dict[str, int] = defaultdict(int)

    # Rows hash into 64-bit buckets; those below this cutoff land in val,
    # giving val a ~val_frac share of every section.
    threshold = int(val_frac * (1 << 64))

    for row in rows:
        sid = row.get("_dlm_section_id")
        if not isinstance(sid, str) or not sid:
            raise ValueError(
                "every row must carry a string `_dlm_section_id` (did you skip sections_to_rows?)"
            )
        # The row's position within its own section: each section's rows
        # split independently of every other section.
        sub_index = per_section_index[sid]
        per_section_index[sid] += 1
        if _assigns_to_val(seed=seed, section_id=sid, sub_index=sub_index, threshold=threshold):
            val_rows.append(row)
        else:
            train_rows.append(row)

    # `Dataset.from_list` infers schema from row[0] only — mixed-shape
    # rows lose keys that don't appear in the first dict. Unify the
    # key-set across BOTH buckets first so train+val share the same
    # schema and no field silently drops out.
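    # For instance (per the note above), from_list([{"a": 1}, {"a": 2, "b": 3}])
    # could come back without column "b".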
    all_keys: set[str] = set()
    for row in rows:
        all_keys.update(row.keys())
    _unify_keys(train_rows, all_keys)
    _unify_keys(val_rows, all_keys)

    return Dataset.from_list(train_rows), Dataset.from_list(val_rows)


def _unify_keys(rows: list[Row], keys: set[str]) -> None:
    # Fill missing keys in place (as None) so every row exposes one schema.
    for row in rows:
        for k in keys:
            row.setdefault(k, None)


def _assigns_to_val(*, seed: int, section_id: str, sub_index: int, threshold: int) -> bool:
    # NUL separators keep the key unambiguous: ("a", 11) and ("a1", 1)
    # serialize to different byte strings.
    key = f"{seed}\x00{section_id}\x00{sub_index}".encode()
    # First 8 digest bytes form one uniform 64-bit bucket, compared against
    # threshold = int(val_frac * 2**64).
    digest = hashlib.sha256(key).digest()[:8]
    bucket = int.from_bytes(digest, byteorder="big", signed=False)
    return bucket < threshold
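

if __name__ == "__main__":
    # Minimal smoke test of the stability invariant from the module
    # docstring. A sketch only: it assumes the `datasets` package is
    # installed, and the section ids / payloads below are made up.
    base = [
        {"_dlm_section_id": f"s{i}", "n": 10 * i + j}
        for i in range(5)
        for j in range(4)
    ]
    extended = base + [{"_dlm_section_id": "s-new", "n": 999}]

    train_a, val_a = split(base, val_frac=0.25, seed=7)
    train_b, _ = split(extended, val_frac=0.25, seed=7)

    # Appending a new section must not reshuffle existing assignments.
    old_train = {r["n"] for r in train_a}
    new_train = {r["n"] for r in train_b if r["_dlm_section_id"] != "s-new"}
    assert old_train == new_train, "existing rows changed buckets"
    print(f"ok: {len(train_a)} train / {len(val_a)} val rows, split is stable")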