1 """Tag-weighted row expansion — deterministic row repetition.
2
3 Operators declare `weights: {tag_key: {tag_value: float}}` in a
4 `.dlm/training.yaml` to up- or down-scale how often rows with that
5 tag appear in the training corpus. We implement it as *row
6 repetition* rather than per-row loss scaling:
7
8 - weight = 1.0 → row appears once (no-op)
9 - weight = 0.0 → row dropped
10 - weight = 2.0 → row appears twice
11 - weight = 2.5 → row appears twice, plus a deterministic 50%
12 chance of a third copy (seeded by section_id)
13 - weight = 0.5 → row appears with deterministic 50% keep probability
14
15 Multiple tag keys compose multiplicatively: a row tagged
16 `{docstring: true, generated: true}` with
17 `{docstring: {true: 2.0}, generated: {true: 0.5}}` ends up at
18 weight 1.0 (= 2.0 × 0.5).
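
The same composition as a doctest (tag values are plain strings,
matching `resolve_row_weight`'s signature):

>>> resolve_row_weight(
...     {"docstring": "true", "generated": "true"},
...     {"docstring": {"true": 2.0}, "generated": {"true": 0.5}},
... )
1.0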

Determinism: the keep/extra-copy decision is a hash of
`(seed, section_id)`. Same seed + same corpus → same expanded
row list, bit-exact. This preserves the determinism guarantee: a
cached run and an uncached run on the same weights config produce
byte-identical adapter weights.
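
For fractional weights the copy count is stable for a fixed seed; a
sketch as a doctest (the `hot` tag and row shape are illustrative):

>>> row = {"_dlm_row_tags": {"hot": "true"}, "_dlm_section_id": "s1"}
>>> out = expand_rows_by_weight([row], {"hot": {"true": 2.5}}, seed=0)
>>> len(out) in (2, 3)
True
>>> out == expand_rows_by_weight([row], {"hot": {"true": 2.5}}, seed=0)
True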

**Why row repetition, not per-row loss scaling?** Bit-identity against
TRL's `_tokenize` would be lost the moment we subclassed
`SFTTrainer.compute_loss` to multiply by a sample-weights tensor —
any TRL internal refactor of the loss path becomes a silent
correctness bug. Expansion is a dataset-level transform; every
downstream layer (pretokenize cache, TRL collator, AdamW) sees a
plain list of rows and stays dumb.
"""

from __future__ import annotations

import hashlib
from collections.abc import Mapping, Sequence
from typing import Any

Row = dict[str, Any]
WeightsMap = Mapping[str, Mapping[str, float]]


def merge_weights_maps(maps: Sequence[WeightsMap]) -> dict[str, dict[str, float]]:
    """Merge a shallowest-to-deepest sequence of weights maps.

    Deeper entries override shallower ones at the `(tag_key, tag_value)`
    grain, matching the nearest-ancestor semantics `.dlm/training.yaml`
    already uses for `metadata` and `exclude`. An empty sequence
    returns `{}`.
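
    A minimal doctest (made-up tag names; note the string keys):

    >>> merge_weights_maps([
    ...     {"docstring": {"true": 2.0}},
    ...     {"docstring": {"true": 0.5}, "generated": {"true": 0.0}},
    ... ])
    {'docstring': {'true': 0.5}, 'generated': {'true': 0.0}}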
    """
    merged: dict[str, dict[str, float]] = {}
    for weights in maps:
        for tag_key, inner in weights.items():
            dst = merged.setdefault(tag_key, {})
            for tag_value, scale in inner.items():
                dst[tag_value] = scale
    return merged


def resolve_row_weight(row_tags: Mapping[str, str], weights: WeightsMap) -> float:
    """Compose the effective weight for a row from its tags + weights map.

    Missing tag keys and unmatched tag values contribute 1.0 (no
    scaling). Matching `(tag_key, tag_value)` entries multiply in.
    Order-independent.
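
    Doctest (tag names are illustrative): an unmatched key falls
    through at 1.0; a matched one scales.

    >>> resolve_row_weight({"lang": "python"}, {"docstring": {"true": 2.0}})
    1.0
    >>> resolve_row_weight({"docstring": "true"}, {"docstring": {"true": 2.0}})
    2.0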
    """
    weight = 1.0
    for tag_key, tag_value in row_tags.items():
        inner = weights.get(tag_key)
        if inner is None:
            continue
        scale = inner.get(tag_value)
        if scale is None:
            continue
        weight *= scale
    return weight


def _keep_fraction(section_id: str, seed: int, fractional: float) -> bool:
    """Deterministic Bernoulli: True with probability `fractional`.

    Uses BLAKE2b over `(seed, section_id)` — cheap, collision-
    resistant, and reproducible across platforms. The section_id is
    stable under the content-addressed store, so the keep/drop
    decision for a given row depends only on seed + content, never
    on row position.
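
    The boundary cases never touch the hash (doctest; `"sec-1"` is a
    made-up section id):

    >>> _keep_fraction("sec-1", 0, 0.0)
    False
    >>> _keep_fraction("sec-1", 0, 1.0)
    True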
    """
    if fractional <= 0.0:
        return False
    if fractional >= 1.0:
        return True
    h = hashlib.blake2b(f"{seed}:{section_id}".encode(), digest_size=8).digest()
    # Map the 8-byte digest to [0, 1) — integer value / 2**64.
    roll = int.from_bytes(h, "big") / float(1 << 64)
    return roll < fractional


def expand_rows_by_weight(
    rows: Sequence[Row],
    weights: WeightsMap,
    *,
    seed: int,
) -> list[Row]:
    """Return a new row list where each input row is repeated (or dropped)
    per its composed weight.

    A row without a `_dlm_row_tags` key gets weight 1.0 (untouched).
    An empty `weights` map is a no-op (returns a shallow copy of
    `rows`). Section-ID preservation means the replay corpus still
    tracks per-row identity — the N copies of a repeated row share
    a section_id, which matches the replay semantics of retraining on
    the same content N times.
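
    Integer weights need no hashing: 0.0 drops, 2.0 duplicates
    (doctest; rows are made up):

    >>> rows = [
    ...     {"text": "a", "_dlm_row_tags": {"generated": "true"}},
    ...     {"text": "b", "_dlm_row_tags": {"docstring": "true"}},
    ... ]
    >>> weights = {"generated": {"true": 0.0}, "docstring": {"true": 2.0}}
    >>> [r["text"] for r in expand_rows_by_weight(rows, weights, seed=0)]
    ['b', 'b']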
    """
    if not weights:
        return list(rows)

    expanded: list[Row] = []
    for row in rows:
        row_tags = row.get("_dlm_row_tags") or {}
        weight = resolve_row_weight(row_tags, weights)
        if weight <= 0.0:
            continue
        integer_copies = int(weight)
        fractional = weight - integer_copies
        for _ in range(integer_copies):
            # Copies alias the same row dict; they share a section_id
            # by design (see the docstring above).
            expanded.append(row)
        if fractional > 0.0:
            section_id = str(row.get("_dlm_section_id", ""))
            if _keep_fraction(section_id, seed, fractional):
                expanded.append(row)
    return expanded


def weight_distribution(
    rows: Sequence[Row],
) -> dict[str, dict[str, int]]:
    """Count original rows per `(tag_key, tag_value)` for summary reporting.

    Takes the pre-expansion row list so users can audit how many rows
    were candidates for each rule, independent of how many copies
    the expansion produced.
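
    Doctest (tags are illustrative):

    >>> weight_distribution([
    ...     {"_dlm_row_tags": {"docstring": "true"}},
    ...     {"_dlm_row_tags": {"docstring": "true", "generated": "true"}},
    ... ])
    {'docstring': {'true': 2}, 'generated': {'true': 1}}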
    """
    dist: dict[str, dict[str, int]] = {}
    for row in rows:
        row_tags = row.get("_dlm_row_tags") or {}
        for tag_key, tag_value in row_tags.items():
            inner = dist.setdefault(tag_key, {})
            inner[tag_value] = inner.get(tag_value, 0) + 1
    return dist