| 1 |
"""Deterministic train / val split keyed on (seed, section_id). |
| 2 |
|
| 3 |
The invariant: adding a section to the `.dlm` does NOT reshuffle the |
| 4 |
existing assignments. Every row's train-vs-val fate is a pure function |
| 5 |
of `(seed, row["_dlm_section_id"], sub_index)` — the sub-index is the |
| 6 |
row's position within its section (so a single INSTRUCTION block with |
| 7 |
ten Q/A pairs distributes those pairs across the split independently). |
| 8 |
|
| 9 |
The split is computed by hashing `(seed, section_id, sub_index)` and |
| 10 |
comparing against `val_frac * 2**64`. This is stable across Python |
| 11 |
versions (we use `hashlib.sha256` rather than `hash()`). |
| 12 |
""" |
| 13 |
|
| 14 |
from __future__ import annotations |
| 15 |
|
| 16 |
import hashlib |
| 17 |
from collections import defaultdict |
| 18 |
from typing import TYPE_CHECKING, Any |
| 19 |
|
| 20 |
if TYPE_CHECKING: |
| 21 |
from datasets import Dataset |
| 22 |
|
| 23 |
Row = dict[str, Any] |
| 24 |
|
| 25 |
|
| 26 |
def split(
    rows: list[Row],
    *,
    val_frac: float,
    seed: int,
) -> tuple[Dataset, Dataset]:
    """Partition `rows` into (train_ds, val_ds) datasets.

    `val_frac` must be in (0, 1). `seed` is combined with each row's
    `_dlm_section_id` + its in-section sub-index to produce a stable
    assignment, so adding a new section never reshuffles existing rows.

    The caller's dicts are NOT mutated: each row is shallow-copied
    before the schema-unification step adds any missing keys.

    Raises `ValueError` if `val_frac` is outside (0, 1) or if any row
    lacks a non-empty string `_dlm_section_id`.
    """
    # Imported lazily so the module stays importable without `datasets`.
    from datasets import Dataset

    # NaN also fails this chained comparison, so it is rejected too.
    if not 0.0 < val_frac < 1.0:
        raise ValueError(f"val_frac must be in (0, 1), got {val_frac!r}")

    train_rows: list[Row] = []
    val_rows: list[Row] = []
    per_section_index: dict[str, int] = defaultdict(int)

    # Bucket threshold: a uniform 64-bit hash lands in val iff it is
    # below val_frac scaled to 2**64.
    threshold = int(val_frac * (1 << 64))

    for row in rows:
        sid = row.get("_dlm_section_id")
        if not isinstance(sid, str) or not sid:
            raise ValueError(
                "every row must carry a string `_dlm_section_id` (did you skip sections_to_rows?)"
            )
        sub_index = per_section_index[sid]
        per_section_index[sid] += 1
        if _assigns_to_val(seed=seed, section_id=sid, sub_index=sub_index, threshold=threshold):
            # dict(row) is a shallow copy — key unification below must
            # not inject keys into the caller's dicts.
            val_rows.append(dict(row))
        else:
            train_rows.append(dict(row))

    # `Dataset.from_list` infers schema from row[0] only — mixed-shape
    # rows lose keys that don't appear in the first dict. Unify the
    # key-set across BOTH buckets first so train+val share the same
    # schema and no field silently drops out.
    all_keys: set[str] = set()
    for row in rows:
        all_keys.update(row.keys())
    _unify_keys(train_rows, all_keys)
    _unify_keys(val_rows, all_keys)

    return Dataset.from_list(train_rows), Dataset.from_list(val_rows)
| 76 |
|
| 77 |
|
| 78 |
def _unify_keys(rows: list[Row], keys: set[str]) -> None: |
| 79 |
for row in rows: |
| 80 |
for k in keys: |
| 81 |
row.setdefault(k, None) |
| 82 |
|
| 83 |
|
| 84 |
def _assigns_to_val(*, seed: int, section_id: str, sub_index: int, threshold: int) -> bool: |
| 85 |
key = f"{seed}\x00{section_id}\x00{sub_index}".encode() |
| 86 |
digest = hashlib.sha256(key).digest()[:8] |
| 87 |
bucket = int.from_bytes(digest, byteorder="big", signed=False) |
| 88 |
return bucket < threshold |