Go · 5141 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package ratelimit
4
5 import (
6 "context"
7 "net/netip"
8 "testing"
9 "time"
10
11 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
12 )
13
14 func TestAllow_UnderLimitThenBlocked(t *testing.T) {
15 t.Parallel()
16 l := New(dbtest.NewTestDB(t))
17 ctx := context.Background()
18 p := Policy{Scope: "test:underthen", Max: 3, Window: time.Hour}
19 key := "k1"
20
21 for i := 1; i <= 3; i++ {
22 d, err := l.Allow(ctx, p, key)
23 if err != nil {
24 t.Fatalf("hit %d err: %v", i, err)
25 }
26 if !d.Allowed {
27 t.Fatalf("hit %d: expected Allowed; got %v", i, d)
28 }
29 if d.Limit != 3 {
30 t.Errorf("hit %d Limit = %d; want 3", i, d.Limit)
31 }
32 if d.Remaining != 3-i {
33 t.Errorf("hit %d Remaining = %d; want %d", i, d.Remaining, 3-i)
34 }
35 }
36 d, err := l.Allow(ctx, p, key)
37 if err != nil {
38 t.Fatalf("over-limit err: %v", err)
39 }
40 if d.Allowed {
41 t.Fatalf("4th hit allowed; want blocked")
42 }
43 if d.RetryAfter <= 0 {
44 t.Fatalf("RetryAfter = %v; want > 0", d.RetryAfter)
45 }
46 }
47
48 func TestAllow_DistinctKeysAreIndependent(t *testing.T) {
49 t.Parallel()
50 l := New(dbtest.NewTestDB(t))
51 ctx := context.Background()
52 p := Policy{Scope: "test:distinct", Max: 1, Window: time.Hour}
53
54 if d, _ := l.Allow(ctx, p, "alice"); !d.Allowed {
55 t.Fatalf("alice first hit blocked: %v", d)
56 }
57 if d, _ := l.Allow(ctx, p, "bob"); !d.Allowed {
58 t.Fatalf("bob first hit blocked: %v", d)
59 }
60 if d, _ := l.Allow(ctx, p, "alice"); d.Allowed {
61 t.Fatalf("alice second hit allowed; want blocked")
62 }
63 }
64
65 func TestAcquireLease_BlocksUntilRelease(t *testing.T) {
66 t.Parallel()
67 l := New(dbtest.NewTestDB(t))
68 ctx := context.Background()
69 p := Policy{Scope: "test:lease", Max: 2, Window: time.Hour}
70
71 lease1, d, err := l.AcquireLease(ctx, p, "alice")
72 if err != nil {
73 t.Fatalf("lease1 err: %v", err)
74 }
75 if !d.Allowed || d.Remaining != 1 {
76 t.Fatalf("lease1 decision=%+v", d)
77 }
78 lease2, d, err := l.AcquireLease(ctx, p, "alice")
79 if err != nil {
80 t.Fatalf("lease2 err: %v", err)
81 }
82 if !d.Allowed || d.Remaining != 0 {
83 t.Fatalf("lease2 decision=%+v", d)
84 }
85 lease3, d, err := l.AcquireLease(ctx, p, "alice")
86 if err != nil {
87 t.Fatalf("lease3 err: %v", err)
88 }
89 if lease3 != nil || d.Allowed {
90 t.Fatalf("lease3=(%v,%+v), want blocked without lease", lease3, d)
91 }
92 if err := lease1.Release(ctx); err != nil {
93 t.Fatalf("release lease1: %v", err)
94 }
95 lease4, d, err := l.AcquireLease(ctx, p, "alice")
96 if err != nil {
97 t.Fatalf("lease4 err: %v", err)
98 }
99 if lease4 == nil || !d.Allowed {
100 t.Fatalf("lease4=(%v,%+v), want allowed", lease4, d)
101 }
102 if err := lease1.Release(ctx); err != nil {
103 t.Fatalf("second release lease1: %v", err)
104 }
105 _ = lease2.Release(ctx)
106 _ = lease4.Release(ctx)
107 }
108
109 func TestAcquireLease_RollsStaleWindow(t *testing.T) {
110 t.Parallel()
111 l := New(dbtest.NewTestDB(t))
112 ctx := context.Background()
113 p := Policy{Scope: "test:lease:ttl", Max: 1, Window: time.Millisecond}
114
115 lease1, d, err := l.AcquireLease(ctx, p, "alice")
116 if err != nil {
117 t.Fatalf("lease1 err: %v", err)
118 }
119 if !d.Allowed {
120 t.Fatalf("lease1 blocked: %+v", d)
121 }
122 time.Sleep(5 * time.Millisecond)
123 lease2, d, err := l.AcquireLease(ctx, p, "alice")
124 if err != nil {
125 t.Fatalf("lease2 err: %v", err)
126 }
127 if lease2 == nil || !d.Allowed {
128 t.Fatalf("stale lease did not roll forward: lease=%v decision=%+v", lease2, d)
129 }
130 _ = lease1.Release(ctx)
131 _ = lease2.Release(ctx)
132 }
133
134 func TestAllow_RejectsBadPolicy(t *testing.T) {
135 t.Parallel()
136 l := New(dbtest.NewTestDB(t))
137 ctx := context.Background()
138 cases := []Policy{
139 {Scope: "", Max: 1, Window: time.Hour},
140 {Scope: "x", Max: 0, Window: time.Hour},
141 {Scope: "x", Max: 1, Window: 0},
142 }
143 for _, p := range cases {
144 if _, err := l.Allow(ctx, p, "k"); err == nil {
145 t.Errorf("Allow(%+v) returned nil err; want non-nil", p)
146 }
147 }
148 }
149
150 func TestMaskToNetwork(t *testing.T) {
151 t.Parallel()
152 cases := []struct {
153 in, want string
154 }{
155 {"192.0.2.5", "192.0.2.0"},
156 {"10.0.0.255", "10.0.0.0"},
157 {"203.0.113.42", "203.0.113.0"},
158 }
159 for _, tc := range cases {
160 got := maskToNetwork(netip.MustParseAddr(tc.in))
161 if got.String() != tc.want {
162 t.Errorf("maskToNetwork(%s) = %s; want %s", tc.in, got, tc.want)
163 }
164 }
165 // IPv6 /48
166 v6 := maskToNetwork(netip.MustParseAddr("2001:db8:dead:beef::1"))
167 if want := "2001:db8:dead::"; v6.String() != want {
168 t.Errorf("v6 mask = %s; want %s", v6, want)
169 }
170 }
171
172 func TestAllowSignupIP(t *testing.T) {
173 t.Parallel()
174 l := New(dbtest.NewTestDB(t))
175 ctx := context.Background()
176
177 // Two IPs in the same /24 share a counter.
178 a := netip.MustParseAddr("198.51.100.5")
179 b := netip.MustParseAddr("198.51.100.99")
180 d1, _ := l.AllowSignupIP(ctx, a, 2, time.Hour)
181 d2, _ := l.AllowSignupIP(ctx, b, 2, time.Hour)
182 d3, _ := l.AllowSignupIP(ctx, a, 2, time.Hour)
183 if !d1.Allowed || !d2.Allowed {
184 t.Fatalf("first two hits should be allowed; got %v %v", d1, d2)
185 }
186 if d3.Allowed {
187 t.Fatalf("third hit (same /24) should be blocked; got %v", d3)
188 }
189
190 // A different /24 starts fresh.
191 other := netip.MustParseAddr("203.0.113.7")
192 d4, _ := l.AllowSignupIP(ctx, other, 2, time.Hour)
193 if !d4.Allowed {
194 t.Fatalf("different /24 first hit should be allowed; got %v", d4)
195 }
196 }
197