"""Tag-weighted row expansion — deterministic row repetition.

Operators declare `weights: {tag_key: {tag_value: float}}` in a
`.dlm/training.yaml` to up- or down-scale how often rows with that
tag appear in the training corpus. We implement it as *row
repetition* rather than per-row loss scaling:

- weight = 1.0 → row appears once (no-op)
- weight = 0.0 → row dropped
- weight = 2.0 → row appears twice
- weight = 2.5 → row appears twice, plus a deterministic 50%
  chance of a third copy (seeded by section_id)
- weight = 0.5 → row appears with deterministic 50% keep probability

Multiple tag keys compose multiplicatively: a row tagged
`{docstring: true, generated: true}` with
`{docstring: {true: 2.0}, generated: {true: 0.5}}` ends up at
weight 1.0 (= 2.0 × 0.5).
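
The same composition, expressed through `resolve_row_weight` below
(a doctest-style sketch; tag values are written as strings here,
matching the `Mapping[str, str]` type hint):

    >>> resolve_row_weight(
    ...     {"docstring": "true", "generated": "true"},
    ...     {"docstring": {"true": 2.0}, "generated": {"true": 0.5}},
    ... )
    1.0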

Determinism: the keep/extra-copy decision is a BLAKE2b hash of
`(seed, section_id)` compared against the fractional part of the
composed weight. Same seed + same corpus → same expanded row list,
bit-exact. This preserves the determinism guarantee: a cached run and
an uncached run on the same weights config produce byte-identical
adapter weights.

**Why row repetition, not per-row loss scaling?** Bit-identity against
TRL's `_tokenize` would be lost the moment we subclassed
`SFTTrainer.compute_loss` to multiply by a sample-weights tensor —
any TRL internal refactor of the loss path becomes a silent
correctness bug. Expansion is a dataset-level transform; every
downstream layer (pretokenize cache, TRL collator, AdamW) sees a
plain list of rows and stays dumb.
"""

from __future__ import annotations

import hashlib
from collections.abc import Mapping, Sequence
from typing import Any

Row = dict[str, Any]
WeightsMap = Mapping[str, Mapping[str, float]]


def merge_weights_maps(maps: Sequence[WeightsMap]) -> dict[str, dict[str, float]]:
    """Merge a shallowest-to-deepest sequence of weights maps.

    Deeper entries override shallower ones at the `(tag_key, tag_value)`
    grain, matching the nearest-ancestor semantics `.dlm/training.yaml`
    already uses for `metadata` and `exclude`. An empty sequence
    returns `{}`.
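
    A doctest-style sketch with hypothetical root and subdirectory maps:

        >>> merge_weights_maps([
        ...     {"generated": {"true": 0.5}},
        ...     {"generated": {"true": 2.0}, "docstring": {"true": 1.5}},
        ... ])
        {'generated': {'true': 2.0}, 'docstring': {'true': 1.5}}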
    """
    merged: dict[str, dict[str, float]] = {}
    for weights in maps:
        for tag_key, inner in weights.items():
            dst = merged.setdefault(tag_key, {})
            for tag_value, scale in inner.items():
                dst[tag_value] = scale
    return merged


def resolve_row_weight(row_tags: Mapping[str, str], weights: WeightsMap) -> float:
    """Compose the effective weight for a row from its tags + weights map.

    Missing tag keys and unmatched tag values contribute 1.0 (no
    scaling). Matching `(tag_key, tag_value)` entries multiply in.
    Order-independent.
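
    For instance (hypothetical tags; a tag key absent from `weights`
    contributes no scaling):

        >>> resolve_row_weight({"language": "python"}, {"generated": {"true": 0.5}})
        1.0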
    """
    weight = 1.0
    for tag_key, tag_value in row_tags.items():
        inner = weights.get(tag_key)
        if inner is None:
            continue
        scale = inner.get(tag_value)
        if scale is None:
            continue
        weight *= scale
    return weight


def _keep_fraction(section_id: str, seed: int, fractional: float) -> bool:
    """Deterministic Bernoulli: True with probability `fractional`.

    Uses BLAKE2b over `(seed, section_id)` — cheap, collision-resistant,
    and reproducible across platforms. The section_id is stable under
    the content-addressed store, so the keep/drop decision for a given
    row depends only on seed + content, never on row position.
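
    Only the clamped edge cases are shown here (the section_id is
    illustrative; mid-range rolls depend on the hash):

        >>> _keep_fraction("sec-1", seed=0, fractional=0.0)
        False
        >>> _keep_fraction("sec-1", seed=0, fractional=1.0)
        True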
    """
    if fractional <= 0.0:
        return False
    if fractional >= 1.0:
        return True
    h = hashlib.blake2b(f"{seed}:{section_id}".encode(), digest_size=8).digest()
    # Map the 8-byte digest to [0, 1): integer / 2**64.
    roll = int.from_bytes(h, "big") / float(1 << 64)
    return roll < fractional


def expand_rows_by_weight(
    rows: Sequence[Row],
    weights: WeightsMap,
    *,
    seed: int,
) -> list[Row]:
    """Return a new row list where each input row is repeated (or dropped)
    per its composed weight.

    A row without a `_dlm_row_tags` key gets weight 1.0 (untouched).
    An empty `weights` map is a no-op (returns a shallow copy of
    `rows`). Section-ID preservation means the replay corpus still
    tracks per-row identity — the N copies of a repeated row share
    a section_id, which matches the replay semantics of retraining on
    the same content N times.
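
    A doctest-style sketch with made-up rows (the seed is irrelevant
    here because neither composed weight has a fractional part):

        >>> rows = [
        ...     {"text": "a", "_dlm_row_tags": {"generated": "true"}},
        ...     {"text": "b", "_dlm_row_tags": {"generated": "false"}},
        ... ]
        >>> expanded = expand_rows_by_weight(
        ...     rows, {"generated": {"true": 0.0, "false": 2.0}}, seed=0
        ... )
        >>> [r["text"] for r in expanded]
        ['b', 'b']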
    """
    if not weights:
        return list(rows)

    expanded: list[Row] = []
    for row in rows:
        row_tags = row.get("_dlm_row_tags") or {}
        weight = resolve_row_weight(row_tags, weights)
        if weight <= 0.0:
            continue
        integer_copies = int(weight)
        fractional = weight - integer_copies
        for _ in range(integer_copies):
            expanded.append(row)
        if fractional > 0.0:
            section_id = str(row.get("_dlm_section_id", ""))
            if _keep_fraction(section_id, seed, fractional):
                expanded.append(row)
    return expanded


def weight_distribution(
    rows: Sequence[Row],
) -> dict[str, dict[str, int]]:
    """Count original rows per `(tag_key, tag_value)` for summary reporting.

    Takes the pre-expansion row list so users can audit how many rows
    were candidates for each rule, independent of how many copies
    the expansion produced.
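
    For example, with hypothetical tags:

        >>> weight_distribution([
        ...     {"_dlm_row_tags": {"generated": "true"}},
        ...     {"_dlm_row_tags": {"generated": "true"}},
        ...     {"_dlm_row_tags": {"generated": "false", "docstring": "true"}},
        ... ])
        {'generated': {'true': 2, 'false': 1}, 'docstring': {'true': 1}}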
    """
    dist: dict[str, dict[str, int]] = {}
    for row in rows:
        row_tags = row.get("_dlm_row_tags") or {}
        for tag_key, tag_value in row_tags.items():
            inner = dist.setdefault(tag_key, {})
            inner[tag_value] = inner.get(tag_value, 0) + 1
    return dist