Go · 7755 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package ratelimit owns the cross-surface counter used by S35's
4 // rate-limit middleware. It generalizes S05's auth/throttle pattern
5 // against the shared rate_limits table so any new surface (API,
6 // search, git transports) plugs in with a single Allow() call.
7 //
8 // Backend: Postgres-backed fixed-window counter via sqlc UPSERT.
9 // At launch scale this is well within Postgres's comfort zone and
10 // avoids introducing a Redis dependency. Move only if profiling
11 // demands it (S36).
12 package ratelimit
13
14 import (
15 "context"
16 "errors"
17 "fmt"
18 "net/netip"
19 "sync/atomic"
20 "time"
21
22 "github.com/jackc/pgx/v5"
23 "github.com/jackc/pgx/v5/pgtype"
24 "github.com/jackc/pgx/v5/pgxpool"
25
26 ratelimitdb "github.com/tenseleyFlow/shithub/internal/ratelimit/sqlc"
27 )
28
29 // Limiter is the package's primary handle. Construct with New.
30 type Limiter struct {
31 q *ratelimitdb.Queries
32 pool *pgxpool.Pool
33 }
34
35 // New wires a Limiter against a pool. The pool is required;
36 // constructing with nil panics so callers fail at boot, not at
37 // first request.
38 func New(pool *pgxpool.Pool) *Limiter {
39 if pool == nil {
40 panic("ratelimit: nil pool")
41 }
42 return &Limiter{q: ratelimitdb.New(), pool: pool}
43 }
44
// Policy describes one limit applied per (Scope, key) pair: no more than
// Max hits are permitted inside a single Window. For leases acquired via
// AcquireLease, Window doubles as the stale-lease TTL.
type Policy struct {
	// Scope names the surface being limited, e.g. "api:anon", "search", "git:https".
	Scope string
	// Max is the number of hits permitted within Window.
	Max int
	// Window is the window length (15s, 1m, 1h, …).
	Window time.Duration
}
52
// Decision is the verdict produced by the limiter's check methods
// (Allow, AcquireLease, AllowSignupIP).
type Decision struct {
	// Allowed reports whether the caller may proceed.
	Allowed bool
	// Remaining is the number of hits left in the current window, counted
	// after this call's increment. Limit mirrors the configured maximum
	// (Policy.Max) so it can be surfaced in the X-RateLimit-Limit header.
	Remaining, Limit int
	// ResetIn is the wall-clock time until the current window rolls over.
	// RetryAfter is zero when Allowed; otherwise it is the wait the client
	// should respect before trying again.
	ResetIn, RetryAfter time.Duration
}
61
62 // Lease represents one held concurrent slot. Release is idempotent.
63 type Lease struct {
64 limiter *Limiter
65 policy Policy
66 key string
67 released atomic.Bool
68 }
69
70 // Allow increments the (scope, key) counter and reports whether the
71 // caller is under or over the configured Max. Returns the post-
72 // increment Remaining + the time until the current window rolls.
73 //
74 // On a Postgres error the request is allowed (fail-open). The caller
75 // is expected to log the error; refusing service over a transient
76 // counter glitch would be worse than the brief over-limit window.
77 func (l *Limiter) Allow(ctx context.Context, p Policy, key string) (Decision, error) {
78 if p.Max <= 0 || p.Window <= 0 {
79 return Decision{}, errors.New("ratelimit: Policy.Max and Window must be positive")
80 }
81 if p.Scope == "" || key == "" {
82 return Decision{}, errors.New("ratelimit: Policy.Scope and key must be non-empty")
83 }
84 row, err := l.q.BumpRateLimit(ctx, l.pool, ratelimitdb.BumpRateLimitParams{
85 Scope: p.Scope,
86 Key: key,
87 Ttl: pgtype.Interval{Microseconds: int64(p.Window / time.Microsecond), Valid: true},
88 })
89 if err != nil {
90 return Decision{Allowed: true, Remaining: p.Max, Limit: p.Max, ResetIn: p.Window}, fmt.Errorf("ratelimit: bump: %w", err)
91 }
92
93 hits := int(row.Hits)
94 resetIn := time.Until(row.WindowStartedAt.Time.Add(p.Window))
95 if resetIn < 0 {
96 resetIn = 0
97 }
98 d := Decision{
99 Allowed: hits <= p.Max,
100 Limit: p.Max,
101 Remaining: max0(p.Max - hits),
102 ResetIn: resetIn,
103 }
104 if !d.Allowed {
105 d.RetryAfter = resetIn
106 if d.RetryAfter <= 0 {
107 d.RetryAfter = time.Second
108 }
109 }
110 return d, nil
111 }
112
113 // AcquireLease holds one concurrent slot for long-lived requests such as SSE
114 // streams. Callers must Release when the request exits. Policy.Window is the
115 // stale-lease TTL: it bounds leak duration if a process exits without release.
116 //
117 // Like Allow, transient Postgres errors fail open. The returned lease is nil in
118 // that case, so callers can continue without a release hook while still logging
119 // the error.
120 func (l *Limiter) AcquireLease(ctx context.Context, p Policy, key string) (*Lease, Decision, error) {
121 if p.Max <= 0 || p.Window <= 0 {
122 return nil, Decision{}, errors.New("ratelimit: Policy.Max and Window must be positive")
123 }
124 if p.Scope == "" || key == "" {
125 return nil, Decision{}, errors.New("ratelimit: Policy.Scope and key must be non-empty")
126 }
127 row, err := l.q.AcquireRateLimitLease(ctx, l.pool, ratelimitdb.AcquireRateLimitLeaseParams{
128 Scope: p.Scope,
129 Key: key,
130 Ttl: pgtype.Interval{Microseconds: int64(p.Window / time.Microsecond), Valid: true},
131 MaxHits: int32(p.Max),
132 })
133 if errors.Is(err, pgx.ErrNoRows) {
134 return nil, l.blockedLeaseDecision(ctx, p, key), nil
135 }
136 if err != nil {
137 return nil, Decision{Allowed: true, Remaining: p.Max, Limit: p.Max, ResetIn: p.Window}, fmt.Errorf("ratelimit: acquire lease: %w", err)
138 }
139 hits := int(row.Hits)
140 resetIn := time.Until(row.WindowStartedAt.Time.Add(p.Window))
141 if resetIn < 0 {
142 resetIn = 0
143 }
144 return &Lease{limiter: l, policy: p, key: key}, Decision{
145 Allowed: true,
146 Limit: p.Max,
147 Remaining: max0(p.Max - hits),
148 ResetIn: resetIn,
149 }, nil
150 }
151
152 // Release returns the held slot to the limiter. It is safe to call multiple
153 // times; only the first call touches postgres.
154 func (l *Lease) Release(ctx context.Context) error {
155 if l == nil || l.limiter == nil {
156 return nil
157 }
158 if !l.released.CompareAndSwap(false, true) {
159 return nil
160 }
161 _, err := l.limiter.q.ReleaseRateLimitLease(ctx, l.limiter.pool, ratelimitdb.ReleaseRateLimitLeaseParams{
162 Scope: l.policy.Scope,
163 Key: l.key,
164 })
165 if err != nil {
166 return fmt.Errorf("ratelimit: release lease: %w", err)
167 }
168 return nil
169 }
170
171 func (l *Limiter) blockedLeaseDecision(ctx context.Context, p Policy, key string) Decision {
172 resetIn := p.Window
173 if row, err := l.q.PeekRateLimit(ctx, l.pool, ratelimitdb.PeekRateLimitParams{Scope: p.Scope, Key: key}); err == nil {
174 resetIn = time.Until(row.WindowStartedAt.Time.Add(p.Window))
175 if resetIn <= 0 {
176 resetIn = time.Second
177 }
178 }
179 return Decision{
180 Allowed: false,
181 Limit: p.Max,
182 Remaining: 0,
183 ResetIn: resetIn,
184 RetryAfter: resetIn,
185 }
186 }
187
188 // AllowSignupIP is the inet-keyed sibling of Allow against the
189 // signup_ip_throttle table. ip is masked to /24 (v4) or /48 (v6)
190 // so a single residential allocation shares one counter — matches
191 // GitHub's approach (per-/24 signup throttle).
192 func (l *Limiter) AllowSignupIP(ctx context.Context, ip netip.Addr, max int, window time.Duration) (Decision, error) {
193 if !ip.IsValid() {
194 return Decision{}, errors.New("ratelimit: invalid ip")
195 }
196 row, err := l.q.BumpSignupIPThrottle(ctx, l.pool, ratelimitdb.BumpSignupIPThrottleParams{
197 Cidr: maskToNetwork(ip),
198 Ttl: pgtype.Interval{Microseconds: int64(window / time.Microsecond), Valid: true},
199 })
200 if err != nil {
201 return Decision{Allowed: true, Remaining: max, Limit: max, ResetIn: window}, fmt.Errorf("ratelimit: signup bump: %w", err)
202 }
203 hits := int(row.Hits)
204 resetIn := time.Until(row.WindowStartedAt.Time.Add(window))
205 if resetIn < 0 {
206 resetIn = 0
207 }
208 d := Decision{
209 Allowed: hits <= max,
210 Limit: max,
211 Remaining: max0(max - hits),
212 ResetIn: resetIn,
213 }
214 if !d.Allowed {
215 d.RetryAfter = resetIn
216 if d.RetryAfter <= 0 {
217 d.RetryAfter = time.Second
218 }
219 }
220 return d, nil
221 }
222
// max0 clamps n at zero: positive values pass through unchanged, anything
// else becomes 0.
func max0(n int) int {
	if n > 0 {
		return n
	}
	return 0
}
229
// maskToNetwork zeros the host bits of ip so per-/24 (v4) or /48 (v6)
// throttle keys collapse to one network row.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are unmapped first and then
// treated as IPv4. Without that step they take the /48 path, where the
// mapped prefix and the embedded v4 address all sit past byte 6, so every
// such client would be reduced to "::" and share a single counter.
func maskToNetwork(ip netip.Addr) netip.Addr {
	ip = ip.Unmap()
	if ip.Is4() {
		b := ip.As4()
		b[3] = 0 // keep the first three octets: /24
		return netip.AddrFrom4(b)
	}
	b := ip.As16()
	for i := 6; i < len(b); i++ { // zero everything past /48
		b[i] = 0
	}
	return netip.AddrFrom16(b)
}
245