Go · 3271 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package ratelimit
4
5 import (
6 "context"
7 "net/netip"
8 "testing"
9 "time"
10
11 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
12 )
13
14 func TestAllow_UnderLimitThenBlocked(t *testing.T) {
15 t.Parallel()
16 l := New(dbtest.NewTestDB(t))
17 ctx := context.Background()
18 p := Policy{Scope: "test:underthen", Max: 3, Window: time.Hour}
19 key := "k1"
20
21 for i := 1; i <= 3; i++ {
22 d, err := l.Allow(ctx, p, key)
23 if err != nil {
24 t.Fatalf("hit %d err: %v", i, err)
25 }
26 if !d.Allowed {
27 t.Fatalf("hit %d: expected Allowed; got %v", i, d)
28 }
29 if d.Limit != 3 {
30 t.Errorf("hit %d Limit = %d; want 3", i, d.Limit)
31 }
32 if d.Remaining != 3-i {
33 t.Errorf("hit %d Remaining = %d; want %d", i, d.Remaining, 3-i)
34 }
35 }
36 d, err := l.Allow(ctx, p, key)
37 if err != nil {
38 t.Fatalf("over-limit err: %v", err)
39 }
40 if d.Allowed {
41 t.Fatalf("4th hit allowed; want blocked")
42 }
43 if d.RetryAfter <= 0 {
44 t.Fatalf("RetryAfter = %v; want > 0", d.RetryAfter)
45 }
46 }
47
48 func TestAllow_DistinctKeysAreIndependent(t *testing.T) {
49 t.Parallel()
50 l := New(dbtest.NewTestDB(t))
51 ctx := context.Background()
52 p := Policy{Scope: "test:distinct", Max: 1, Window: time.Hour}
53
54 if d, _ := l.Allow(ctx, p, "alice"); !d.Allowed {
55 t.Fatalf("alice first hit blocked: %v", d)
56 }
57 if d, _ := l.Allow(ctx, p, "bob"); !d.Allowed {
58 t.Fatalf("bob first hit blocked: %v", d)
59 }
60 if d, _ := l.Allow(ctx, p, "alice"); d.Allowed {
61 t.Fatalf("alice second hit allowed; want blocked")
62 }
63 }
64
65 func TestAllow_RejectsBadPolicy(t *testing.T) {
66 t.Parallel()
67 l := New(dbtest.NewTestDB(t))
68 ctx := context.Background()
69 cases := []Policy{
70 {Scope: "", Max: 1, Window: time.Hour},
71 {Scope: "x", Max: 0, Window: time.Hour},
72 {Scope: "x", Max: 1, Window: 0},
73 }
74 for _, p := range cases {
75 if _, err := l.Allow(ctx, p, "k"); err == nil {
76 t.Errorf("Allow(%+v) returned nil err; want non-nil", p)
77 }
78 }
79 }
80
81 func TestMaskToNetwork(t *testing.T) {
82 t.Parallel()
83 cases := []struct {
84 in, want string
85 }{
86 {"192.0.2.5", "192.0.2.0"},
87 {"10.0.0.255", "10.0.0.0"},
88 {"203.0.113.42", "203.0.113.0"},
89 }
90 for _, tc := range cases {
91 got := maskToNetwork(netip.MustParseAddr(tc.in))
92 if got.String() != tc.want {
93 t.Errorf("maskToNetwork(%s) = %s; want %s", tc.in, got, tc.want)
94 }
95 }
96 // IPv6 /48
97 v6 := maskToNetwork(netip.MustParseAddr("2001:db8:dead:beef::1"))
98 if want := "2001:db8:dead::"; v6.String() != want {
99 t.Errorf("v6 mask = %s; want %s", v6, want)
100 }
101 }
102
103 func TestAllowSignupIP(t *testing.T) {
104 t.Parallel()
105 l := New(dbtest.NewTestDB(t))
106 ctx := context.Background()
107
108 // Two IPs in the same /24 share a counter.
109 a := netip.MustParseAddr("198.51.100.5")
110 b := netip.MustParseAddr("198.51.100.99")
111 d1, _ := l.AllowSignupIP(ctx, a, 2, time.Hour)
112 d2, _ := l.AllowSignupIP(ctx, b, 2, time.Hour)
113 d3, _ := l.AllowSignupIP(ctx, a, 2, time.Hour)
114 if !d1.Allowed || !d2.Allowed {
115 t.Fatalf("first two hits should be allowed; got %v %v", d1, d2)
116 }
117 if d3.Allowed {
118 t.Fatalf("third hit (same /24) should be blocked; got %v", d3)
119 }
120
121 // A different /24 starts fresh.
122 other := netip.MustParseAddr("203.0.113.7")
123 d4, _ := l.AllowSignupIP(ctx, other, 2, time.Hour)
124 if !d4.Allowed {
125 t.Fatalf("different /24 first hit should be allowed; got %v", d4)
126 }
127 }
128