Python · 5041 bytes Raw Blame History
1 """Sampler determinism, recency bias, and contract."""
2
3 from __future__ import annotations
4
5 import random
6 from datetime import datetime, timedelta
7
8 import pytest
9
10 from dlm.replay.errors import SamplerError
11 from dlm.replay.models import IndexEntry
12 from dlm.replay.sampler import _weighted_reservoir, sample
13
14 _NOW = datetime(2026, 4, 18)
15
16
def _entries(n: int, *, age_step_days: float = 1.0) -> list[IndexEntry]:
    """Build ``n`` synthetic index entries, newest first.

    Entry ``i`` gets a 16-digit zero-padded hex ``section_id``, a 100-byte
    span starting at offset ``i * 100``, and an ``added_at`` lying
    ``i * age_step_days`` days before ``_NOW``.
    """
    fixtures: list[IndexEntry] = []
    for i in range(n):
        age = timedelta(days=i * age_step_days)
        fixtures.append(
            IndexEntry(
                section_id=f"{i:016x}",
                byte_offset=i * 100,
                length=100,
                added_at=_NOW - age,
            )
        )
    return fixtures
27
28
class TestDeterminism:
    """Seeded sampling is reproducible and sensitive to the seed."""

    def test_same_seed_same_sample(self) -> None:
        """Two runs with identically seeded RNGs select the same section ids."""
        pool = _entries(100)
        first = sample(pool, k=10, now=_NOW, rng=random.Random(42))
        second = sample(pool, k=10, now=_NOW, rng=random.Random(42))
        first_ids = [entry.section_id for entry in first]
        second_ids = [entry.section_id for entry in second]
        assert first_ids == second_ids

    def test_different_seed_different_sample(self) -> None:
        """Runs seeded with 0 and 1 must not produce the same id sequence."""
        pool = _entries(100)
        from_seed_zero = sample(pool, k=10, now=_NOW, rng=random.Random(0))
        from_seed_one = sample(pool, k=10, now=_NOW, rng=random.Random(1))
        ids_zero = [entry.section_id for entry in from_seed_zero]
        ids_one = [entry.section_id for entry in from_seed_one]
        assert ids_zero != ids_one
41
42
class TestBoundaries:
    """Edge cases for ``k``, empty input, and the ``scheme`` argument."""

    def test_empty_entries(self) -> None:
        """Sampling from an empty list yields an empty list."""
        picked = sample([], k=5, now=_NOW, rng=random.Random(0))
        assert picked == []

    def test_k_zero_returns_empty(self) -> None:
        """Requesting zero items yields an empty list even with entries available."""
        pool = _entries(10)
        picked = sample(pool, k=0, now=_NOW, rng=random.Random(0))
        assert picked == []

    def test_k_negative_rejected(self) -> None:
        """A negative ``k`` raises ``SamplerError`` mentioning "non-negative"."""
        pool = _entries(10)
        rng = random.Random(0)
        with pytest.raises(SamplerError, match="non-negative"):
            sample(pool, k=-1, now=_NOW, rng=rng)

    def test_k_exceeds_size_returns_all(self) -> None:
        """Asking for more than exists returns every entry, ordered by section id."""
        pool = _entries(5)
        picked = sample(pool, k=10, now=_NOW, rng=random.Random(0))
        assert len(picked) == 5
        picked_ids = [entry.section_id for entry in picked]
        assert picked_ids == sorted(entry.section_id for entry in pool)

    def test_unknown_scheme_raises(self) -> None:
        """An unrecognised ``scheme`` raises ``SamplerError`` mentioning "scheme"."""
        pool = _entries(5)
        rng = random.Random(0)
        with pytest.raises(SamplerError, match="scheme"):
            sample(
                pool,
                k=3,
                now=_NOW,
                rng=rng,
                scheme="frequency",  # type: ignore[arg-type]
            )
69
70
class TestRecencyBias:
    """Weighting behaviour: newer entries are favoured, zero weight excludes."""

    def test_recent_sections_more_likely(self) -> None:
        """Average age of recency picks must be younger than uniform over many seeds."""
        entries = _entries(200, age_step_days=2.0)
        recency_total_age = 0
        uniform_total_age = 0
        trials = 30
        for seed in range(trials):
            # Same seed for both schemes so only the weighting differs.
            rec = sample(entries, k=20, now=_NOW, rng=random.Random(seed), scheme="recency")
            uni = sample(entries, k=20, now=_NOW, rng=random.Random(seed), scheme="uniform")
            recency_total_age += sum((_NOW - e.added_at).days for e in rec)
            uniform_total_age += sum((_NOW - e.added_at).days for e in uni)
        # Both schemes are asked for the same k over the same trials, so the
        # age totals stand in for the averages.
        assert recency_total_age < uniform_total_age

    def test_zero_weight_entry_never_sampled(self) -> None:
        """An entry with ``weight=0.0`` is never the single pick, for 20 seeds."""
        entries = [
            IndexEntry(
                section_id="a" * 16,
                byte_offset=0,
                length=100,
                added_at=_NOW,
                weight=1.0,
            ),
            IndexEntry(
                section_id="b" * 16,
                byte_offset=100,
                length=100,
                added_at=_NOW,
                weight=0.0,
            ),
        ]
        # k=1, 20 trials; weight=0 must never appear.
        for seed in range(20):
            picked = sample(entries, k=1, now=_NOW, rng=random.Random(seed))
            # Compare the whole id list: an empty result then fails as a clear
            # assertion (not an IndexError), and any extra entry is caught too.
            assert [e.section_id for e in picked] == ["a" * 16]
106
107
class TestStableOrdering:
    """The sample must not depend on the order of the input list."""

    def test_stable_input_ordering_irrespective_of_list_order(self) -> None:
        """Same entries in reversed order, same seed: identical picks expected."""
        forward = _entries(50)
        backward = forward[::-1]
        from_forward = sample(forward, k=5, now=_NOW, rng=random.Random(7), scheme="uniform")
        from_backward = sample(backward, k=5, now=_NOW, rng=random.Random(7), scheme="uniform")
        forward_ids = [entry.section_id for entry in from_forward]
        backward_ids = [entry.section_id for entry in from_backward]
        assert forward_ids == backward_ids
116
117
class TestReservoirEdgeCases:
    """Corner cases: exact-zero random draws and non-positive weights."""

    def test_zero_random_draw_retries_and_falls_back_to_tiny_positive(self) -> None:
        """``sample`` still returns one entry when the RNG's first draws are 0.0."""
        pool = _entries(2)

        class _ScriptedRng:
            """Stub RNG whose ``random()`` yields 0.0 twice, then 0.5 twice."""

            def __init__(self) -> None:
                self._draws = iter([0.0, 0.0, 0.5, 0.5])

            def random(self) -> float:
                # Raises StopIteration once the four scripted draws are used up.
                return next(self._draws)

        stub_rng = _ScriptedRng()
        picked = sample(pool, k=1, now=_NOW, rng=stub_rng, scheme="uniform")
        assert len(picked) == 1

    def test_nonpositive_weight_entries_are_skipped(self) -> None:
        """Only the entry with a positive weight survives, even with ``k=3``."""
        pool = _entries(3)
        kept = _weighted_reservoir(
            pool,
            weights=[1.0, 0.0, -1.0],
            k=3,
            rng=random.Random(0),
        )
        kept_ids = [entry.section_id for entry in kept]
        assert kept_ids == [pool[0].section_id]