Python · 5987 bytes Raw Blame History
1 """Tag-weighted row expansion — math + determinism + edge cases."""
2
3 from __future__ import annotations
4
5 from dlm.data.weighted_rows import (
6 _keep_fraction,
7 expand_rows_by_weight,
8 merge_weights_maps,
9 resolve_row_weight,
10 weight_distribution,
11 )
12
13
14 def _row(section_id: str, tags: dict[str, str] | None = None) -> dict:
15 return {
16 "text": f"body-{section_id}",
17 "_dlm_section_id": section_id,
18 "_dlm_row_tags": tags or {},
19 }
20
21
class TestResolveRowWeight:
    """`resolve_row_weight(tags, weights)` -> the scale factor for a single row."""

    def test_empty_weights_gives_unity(self) -> None:
        assert resolve_row_weight({"a": "b"}, {}) == 1.0

    def test_empty_tags_gives_unity(self) -> None:
        assert resolve_row_weight({}, {"a": {"b": 2.0}}) == 1.0

    def test_single_matching_tag_scales(self) -> None:
        assert resolve_row_weight({"docstring": "true"}, {"docstring": {"true": 2.5}}) == 2.5

    def test_multiple_keys_multiply(self) -> None:
        # The product (3.0) differs from unity and from either factor alone,
        # so this fails if matches are ignored or only one key is applied.
        # (2.0 x 0.5 == 1.0 would not catch an implementation that ignores
        # every match.) All values are exact binary fractions, so == is safe.
        weights = {"a": {"x": 2.0}, "b": {"y": 1.5}}
        assert resolve_row_weight({"a": "x", "b": "y"}, weights) == 3.0

    def test_downweight_composes_with_upweight(self) -> None:
        # An up-weight and a down-weight still multiply: 2.0 x 0.25 == 0.5.
        weights = {"a": {"x": 2.0}, "b": {"y": 0.25}}
        assert resolve_row_weight({"a": "x", "b": "y"}, weights) == 0.5

    def test_only_matching_keys_contribute(self) -> None:
        # Key "b" is present on the row but with an unweighted value, so only
        # the factor for "a" applies.
        weights = {"a": {"x": 2.0}, "b": {"y": 3.0}}
        assert resolve_row_weight({"a": "x", "b": "nope"}, weights) == 2.0

    def test_unmatched_value_does_not_scale(self) -> None:
        assert resolve_row_weight({"a": "z"}, {"a": {"x": 2.0}}) == 1.0

    def test_unmatched_key_does_not_scale(self) -> None:
        assert resolve_row_weight({"b": "x"}, {"a": {"x": 2.0}}) == 1.0
41
42
class TestExpandRowsByWeight:
    """`expand_rows_by_weight(rows, weights, seed=...)` repeats or drops rows by weight.

    The integer part of a row's weight gives unconditional copies. The
    fractional part gives one extra copy via a seeded Bernoulli roll.
    """

    def test_empty_weights_returns_shallow_copy(self) -> None:
        rows = [_row("01"), _row("02")]
        out = expand_rows_by_weight(rows, {}, seed=42)
        assert out == rows
        assert out is not rows  # shallow copy

    def test_weight_one_is_noop(self) -> None:
        rows = [_row("01", {"k": "v"})]
        out = expand_rows_by_weight(rows, {"k": {"v": 1.0}}, seed=42)
        assert len(out) == 1
        # Check the surviving row is the input row, not merely any single row.
        assert out[0]["_dlm_section_id"] == "01"

    def test_integer_weight_repeats(self) -> None:
        rows = [_row("01", {"k": "v"})]
        out = expand_rows_by_weight(rows, {"k": {"v": 3.0}}, seed=42)
        assert len(out) == 3
        # All copies share the same section_id.
        assert {r["_dlm_section_id"] for r in out} == {"01"}

    def test_zero_weight_drops(self) -> None:
        rows = [_row("01", {"k": "v"}), _row("02", {"other": "x"})]
        out = expand_rows_by_weight(rows, {"k": {"v": 0.0}}, seed=42)
        # Row 02 is untagged for `k` so it keeps weight 1.
        assert len(out) == 1
        assert out[0]["_dlm_section_id"] == "02"

    def test_fractional_weight_is_deterministic(self) -> None:
        rows = [_row(f"{i:02d}", {"k": "v"}) for i in range(100)]
        out1 = expand_rows_by_weight(rows, {"k": {"v": 0.5}}, seed=42)
        out2 = expand_rows_by_weight(rows, {"k": {"v": 0.5}}, seed=42)
        # Guard against a vacuous pass: keeping everything (or nothing) would
        # be "deterministic" without exercising the Bernoulli roll at all.
        assert 0 < len(out1) < len(rows)
        assert [r["_dlm_section_id"] for r in out1] == [r["_dlm_section_id"] for r in out2]

    def test_fractional_weight_approximates_probability(self) -> None:
        rows = [_row(f"{i:04d}", {"k": "v"}) for i in range(1000)]
        out = expand_rows_by_weight(rows, {"k": {"v": 0.5}}, seed=42)
        # 50% keep rate with 1000 rows should land within ±10% of 500
        # (about ±3σ for Binomial(1000, 0.5), σ ≈ 15.8).
        assert 450 <= len(out) <= 550

    def test_weight_between_one_and_two_includes_integer_plus_fractional(self) -> None:
        rows = [_row(f"{i:04d}", {"k": "v"}) for i in range(1000)]
        out = expand_rows_by_weight(rows, {"k": {"v": 1.5}}, seed=42)
        # Every row gets 1 copy unconditionally plus ~50% get a 2nd.
        assert 1450 <= len(out) <= 1550

    def test_different_seeds_yield_different_expansions(self) -> None:
        rows = [_row(f"{i:02d}", {"k": "v"}) for i in range(100)]
        out1 = expand_rows_by_weight(rows, {"k": {"v": 0.5}}, seed=42)
        out2 = expand_rows_by_weight(rows, {"k": {"v": 0.5}}, seed=43)
        # Not byte-identical — different seeds drive different Bernoulli rolls.
        ids1 = [r["_dlm_section_id"] for r in out1]
        ids2 = [r["_dlm_section_id"] for r in out2]
        assert ids1 != ids2

    def test_rows_without_tags_get_unity_weight(self) -> None:
        rows = [_row("01", {}), _row("02", {"k": "v"})]
        out = expand_rows_by_weight(rows, {"k": {"v": 3.0}}, seed=42)
        # Row 01 = 1 copy (unity). Row 02 = 3 copies. Checking per-row counts
        # rather than just the total of 4 pins which row received the weight.
        ids = [r["_dlm_section_id"] for r in out]
        assert len(ids) == 4
        assert ids.count("01") == 1
        assert ids.count("02") == 3

    def test_multiplicative_composition(self) -> None:
        rows = [_row("01", {"a": "x", "b": "y"})]
        weights = {"a": {"x": 2.0}, "b": {"y": 3.0}}
        out = expand_rows_by_weight(rows, weights, seed=42)
        # 2.0 × 3.0 = 6 copies.
        assert len(out) == 6
108
109
class TestMergeWeightsMaps:
    """`merge_weights_maps(maps)` folds weight maps in order; later entries win."""

    def test_empty_sequence_returns_empty_map(self) -> None:
        merged = merge_weights_maps([])
        assert merged == {}

    def test_deeper_entries_override_shallower_ones(self) -> None:
        shallow = {"lang": {"py": 2.0, "rs": 1.5}, "gen": {"true": 0.5}}
        deep = {"lang": {"py": 3.0}, "new": {"x": 4.0}}
        expected = {
            "lang": {"py": 3.0, "rs": 1.5},  # py overridden, rs inherited
            "gen": {"true": 0.5},  # present only in the shallow map
            "new": {"x": 4.0},  # present only in the deep map
        }
        assert merge_weights_maps([shallow, deep]) == expected
126
127
class TestKeepFraction:
    """Boundary behaviour of the seeded Bernoulli keep decision `_keep_fraction`.

    The test names promise the whole ranges (<= 0 never keeps, >= 1 always
    keeps), so each test covers the boundary, an interior value, and a few
    section ids rather than a single point.
    """

    _SECTION_IDS = ("sid", "00", "abc")

    def test_non_positive_fraction_never_keeps(self) -> None:
        for sid in self._SECTION_IDS:
            for fractional in (0.0, -0.5):
                assert _keep_fraction(sid, seed=42, fractional=fractional) is False

    def test_fraction_at_or_above_one_always_keeps(self) -> None:
        for sid in self._SECTION_IDS:
            for fractional in (1.0, 1.5):
                assert _keep_fraction(sid, seed=42, fractional=fractional) is True
134
135
class TestWeightDistribution:
    """`weight_distribution(rows)` counts how many rows carry each tag value."""

    def test_empty_rows_empty_dist(self) -> None:
        dist = weight_distribution([])
        assert dist == {}

    def test_untagged_rows_produce_empty_dist(self) -> None:
        untagged = [_row(sid) for sid in ("01", "02")]
        assert weight_distribution(untagged) == {}

    def test_counts_per_tag_value(self) -> None:
        # (section_id, lang, gen) for each fixture row.
        specs = [
            ("01", "py", "true"),
            ("02", "py", "false"),
            ("03", "rs", "false"),
        ]
        rows = [_row(sid, {"lang": lang, "gen": gen}) for sid, lang, gen in specs]
        expected = {
            "lang": {"py": 2, "rs": 1},
            "gen": {"true": 1, "false": 2},
        }
        assert weight_distribution(rows) == expected