S35: ratelimit — token-bucket Allow + middleware + IP keyer
- SHA
d1e0592205f27fa98545c77daa9d941825c582da- Parents
-
d2d96f3 - Tree
435bd61
d1e0592
d1e0592205f27fa98545c77daa9d941825c582dad2d96f3
435bd61| Status | File | + | - |
|---|---|---|---|
| A |
internal/ratelimit/bucket.go
|
159 | 0 |
| A |
internal/ratelimit/keys.go
|
61 | 0 |
| A |
internal/ratelimit/middleware.go
|
69 | 0 |
internal/ratelimit/bucket.goadded@@ -0,0 +1,159 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +// Package ratelimit owns the cross-surface counter used by S35's | |
| 4 | +// rate-limit middleware. It generalizes S05's auth/throttle pattern | |
| 5 | +// against the shared rate_limits table so any new surface (API, | |
| 6 | +// search, git transports) plugs in with a single Allow() call. | |
| 7 | +// | |
| 8 | +// Backend: Postgres-backed fixed-window counter via sqlc UPSERT. | |
| 9 | +// At launch scale this is well within Postgres's comfort zone and | |
| 10 | +// avoids introducing a Redis dependency. Move only if profiling | |
| 11 | +// demands it (S36). | |
| 12 | +package ratelimit | |
| 13 | + | |
| 14 | +import ( | |
| 15 | + "context" | |
| 16 | + "errors" | |
| 17 | + "fmt" | |
| 18 | + "net/netip" | |
| 19 | + "time" | |
| 20 | + | |
| 21 | + "github.com/jackc/pgx/v5/pgtype" | |
| 22 | + "github.com/jackc/pgx/v5/pgxpool" | |
| 23 | + | |
| 24 | + ratelimitdb "github.com/tenseleyFlow/shithub/internal/ratelimit/sqlc" | |
| 25 | +) | |
| 26 | + | |
| 27 | +// Limiter is the package's primary handle. Construct with New. | |
| 28 | +type Limiter struct { | |
| 29 | + q *ratelimitdb.Queries | |
| 30 | + pool *pgxpool.Pool | |
| 31 | +} | |
| 32 | + | |
| 33 | +// New wires a Limiter against a pool. The pool is required; | |
| 34 | +// constructing with nil panics so callers fail at boot, not at | |
| 35 | +// first request. | |
| 36 | +func New(pool *pgxpool.Pool) *Limiter { | |
| 37 | + if pool == nil { | |
| 38 | + panic("ratelimit: nil pool") | |
| 39 | + } | |
| 40 | + return &Limiter{q: ratelimitdb.New(), pool: pool} | |
| 41 | +} | |
| 42 | + | |
| 43 | +// Policy declares a per-(scope, key) limit: at most Max hits within | |
| 44 | +// the rolling Window. | |
| 45 | +type Policy struct { | |
| 46 | + Scope string // e.g. "api:anon", "search", "git:https" | |
| 47 | + Max int // hits permitted within Window | |
| 48 | + Window time.Duration // window length (15s, 1m, 1h, …) | |
| 49 | +} | |
| 50 | + | |
| 51 | +// Decision is the verdict from Allow. | |
| 52 | +type Decision struct { | |
| 53 | + Allowed bool | |
| 54 | + Remaining int // hits left in the current window (post-increment) | |
| 55 | + Limit int // == Policy.Max, surfaced for the X-RateLimit-Limit header | |
| 56 | + ResetIn time.Duration // wall-clock until the current window rolls over | |
| 57 | + RetryAfter time.Duration // 0 when Allowed; otherwise the wait the client should respect | |
| 58 | +} | |
| 59 | + | |
| 60 | +// Allow increments the (scope, key) counter and reports whether the | |
| 61 | +// caller is under or over the configured Max. Returns the post- | |
| 62 | +// increment Remaining + the time until the current window rolls. | |
| 63 | +// | |
| 64 | +// On a Postgres error the request is allowed (fail-open). The caller | |
| 65 | +// is expected to log the error; refusing service over a transient | |
| 66 | +// counter glitch would be worse than the brief over-limit window. | |
| 67 | +func (l *Limiter) Allow(ctx context.Context, p Policy, key string) (Decision, error) { | |
| 68 | + if p.Max <= 0 || p.Window <= 0 { | |
| 69 | + return Decision{}, errors.New("ratelimit: Policy.Max and Window must be positive") | |
| 70 | + } | |
| 71 | + if p.Scope == "" || key == "" { | |
| 72 | + return Decision{}, errors.New("ratelimit: Policy.Scope and key must be non-empty") | |
| 73 | + } | |
| 74 | + row, err := l.q.BumpRateLimit(ctx, l.pool, ratelimitdb.BumpRateLimitParams{ | |
| 75 | + Scope: p.Scope, | |
| 76 | + Key: key, | |
| 77 | + Ttl: pgtype.Interval{Microseconds: int64(p.Window / time.Microsecond), Valid: true}, | |
| 78 | + }) | |
| 79 | + if err != nil { | |
| 80 | + return Decision{Allowed: true, Remaining: p.Max, Limit: p.Max, ResetIn: p.Window}, fmt.Errorf("ratelimit: bump: %w", err) | |
| 81 | + } | |
| 82 | + | |
| 83 | + hits := int(row.Hits) | |
| 84 | + resetIn := time.Until(row.WindowStartedAt.Time.Add(p.Window)) | |
| 85 | + if resetIn < 0 { | |
| 86 | + resetIn = 0 | |
| 87 | + } | |
| 88 | + d := Decision{ | |
| 89 | + Allowed: hits <= p.Max, | |
| 90 | + Limit: p.Max, | |
| 91 | + Remaining: max0(p.Max - hits), | |
| 92 | + ResetIn: resetIn, | |
| 93 | + } | |
| 94 | + if !d.Allowed { | |
| 95 | + d.RetryAfter = resetIn | |
| 96 | + if d.RetryAfter <= 0 { | |
| 97 | + d.RetryAfter = time.Second | |
| 98 | + } | |
| 99 | + } | |
| 100 | + return d, nil | |
| 101 | +} | |
| 102 | + | |
| 103 | +// AllowSignupIP is the inet-keyed sibling of Allow against the | |
| 104 | +// signup_ip_throttle table. ip is masked to /24 (v4) or /48 (v6) | |
| 105 | +// so a single residential allocation shares one counter — matches | |
| 106 | +// GitHub's approach (per-/24 signup throttle). | |
| 107 | +func (l *Limiter) AllowSignupIP(ctx context.Context, ip netip.Addr, max int, window time.Duration) (Decision, error) { | |
| 108 | + if !ip.IsValid() { | |
| 109 | + return Decision{}, errors.New("ratelimit: invalid ip") | |
| 110 | + } | |
| 111 | + row, err := l.q.BumpSignupIPThrottle(ctx, l.pool, ratelimitdb.BumpSignupIPThrottleParams{ | |
| 112 | + Cidr: maskToNetwork(ip), | |
| 113 | + Ttl: pgtype.Interval{Microseconds: int64(window / time.Microsecond), Valid: true}, | |
| 114 | + }) | |
| 115 | + if err != nil { | |
| 116 | + return Decision{Allowed: true, Remaining: max, Limit: max, ResetIn: window}, fmt.Errorf("ratelimit: signup bump: %w", err) | |
| 117 | + } | |
| 118 | + hits := int(row.Hits) | |
| 119 | + resetIn := time.Until(row.WindowStartedAt.Time.Add(window)) | |
| 120 | + if resetIn < 0 { | |
| 121 | + resetIn = 0 | |
| 122 | + } | |
| 123 | + d := Decision{ | |
| 124 | + Allowed: hits <= max, | |
| 125 | + Limit: max, | |
| 126 | + Remaining: max0(max - hits), | |
| 127 | + ResetIn: resetIn, | |
| 128 | + } | |
| 129 | + if !d.Allowed { | |
| 130 | + d.RetryAfter = resetIn | |
| 131 | + if d.RetryAfter <= 0 { | |
| 132 | + d.RetryAfter = time.Second | |
| 133 | + } | |
| 134 | + } | |
| 135 | + return d, nil | |
| 136 | +} | |
| 137 | + | |
| 138 | +func max0(n int) int { | |
| 139 | + if n < 0 { | |
| 140 | + return 0 | |
| 141 | + } | |
| 142 | + return n | |
| 143 | +} | |
| 144 | + | |
| 145 | +// maskToNetwork zeros the host bits of ip so per-/24 (v4) or /48 (v6) | |
| 146 | +// throttle keys collapse to one network row. The choice of /24 + /48 | |
| 147 | +// matches GitHub's reported anti-abuse defaults. | |
| 148 | +func maskToNetwork(ip netip.Addr) netip.Addr { | |
| 149 | + if ip.Is4() { | |
| 150 | + b := ip.As4() | |
| 151 | + b[3] = 0 | |
| 152 | + return netip.AddrFrom4(b) | |
| 153 | + } | |
| 154 | + b := ip.As16() | |
| 155 | + for i := 6; i < 16; i++ { // zero everything past /48 | |
| 156 | + b[i] = 0 | |
| 157 | + } | |
| 158 | + return netip.AddrFrom16(b) | |
| 159 | +} | |
internal/ratelimit/keys.goadded@@ -0,0 +1,61 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package ratelimit | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "net" | |
| 7 | + "net/http" | |
| 8 | + "net/netip" | |
| 9 | + "strings" | |
| 10 | +) | |
| 11 | + | |
| 12 | +// IPKey is the canonical anonymous-request keyer: extracts the | |
| 13 | +// client IP from r.RemoteAddr or the X-Forwarded-For header. The | |
| 14 | +// trust-XFF flag should be set ONLY when the deployment runs | |
| 15 | +// behind a CDN/proxy we control; otherwise an attacker can spoof | |
| 16 | +// the header and dodge IP-keyed limits. | |
| 17 | +func IPKey(trustForwarded bool) KeyFunc { | |
| 18 | + return func(r *http.Request) string { | |
| 19 | + if trustForwarded { | |
| 20 | + if v := r.Header.Get("X-Forwarded-For"); v != "" { | |
| 21 | + // First entry is the client; downstream proxies append. | |
| 22 | + if comma := strings.IndexByte(v, ','); comma > 0 { | |
| 23 | + v = v[:comma] | |
| 24 | + } | |
| 25 | + return "ip:" + strings.TrimSpace(v) | |
| 26 | + } | |
| 27 | + } | |
| 28 | + host, _, err := net.SplitHostPort(r.RemoteAddr) | |
| 29 | + if err != nil { | |
| 30 | + host = r.RemoteAddr | |
| 31 | + } | |
| 32 | + return "ip:" + host | |
| 33 | + } | |
| 34 | +} | |
| 35 | + | |
| 36 | +// ClientIP returns the parsed client IP using the same rules as | |
| 37 | +// IPKey. Used by signup-throttle's CIDR keying. | |
| 38 | +func ClientIP(r *http.Request, trustForwarded bool) (netip.Addr, bool) { | |
| 39 | + raw := "" | |
| 40 | + if trustForwarded { | |
| 41 | + if v := r.Header.Get("X-Forwarded-For"); v != "" { | |
| 42 | + if comma := strings.IndexByte(v, ','); comma > 0 { | |
| 43 | + v = v[:comma] | |
| 44 | + } | |
| 45 | + raw = strings.TrimSpace(v) | |
| 46 | + } | |
| 47 | + } | |
| 48 | + if raw == "" { | |
| 49 | + host, _, err := net.SplitHostPort(r.RemoteAddr) | |
| 50 | + if err == nil { | |
| 51 | + raw = host | |
| 52 | + } else { | |
| 53 | + raw = r.RemoteAddr | |
| 54 | + } | |
| 55 | + } | |
| 56 | + addr, err := netip.ParseAddr(raw) | |
| 57 | + if err != nil { | |
| 58 | + return netip.Addr{}, false | |
| 59 | + } | |
| 60 | + return addr, true | |
| 61 | +} | |
internal/ratelimit/middleware.goadded@@ -0,0 +1,69 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package ratelimit | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "net/http" | |
| 7 | + "strconv" | |
| 8 | + "time" | |
| 9 | +) | |
| 10 | + | |
| 11 | +// KeyFunc derives the per-request rate-limit key. Common choices: | |
| 12 | +// remote IP (anon), user id (authed), token id (PAT), repo id+user | |
| 13 | +// (per-repo), etc. Returning "" skips the limiter (caller should | |
| 14 | +// only return "" when there's nothing meaningful to throttle on). | |
| 15 | +type KeyFunc func(*http.Request) string | |
| 16 | + | |
| 17 | +// Middleware wraps next with a per-request Allow check. On allow, | |
| 18 | +// X-RateLimit-* headers are stamped on the response; on deny, the | |
| 19 | +// response is 429 with Retry-After. | |
| 20 | +// | |
| 21 | +// `name` is the per-route identifier surfaced in logs and lets the | |
| 22 | +// admin observability pages disambiguate scopes that share the same | |
| 23 | +// counter (e.g. "api:anon" used by both /api/repos and /api/users). | |
| 24 | +func (l *Limiter) Middleware(p Policy, key KeyFunc) func(http.Handler) http.Handler { | |
| 25 | + return func(next http.Handler) http.Handler { | |
| 26 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| 27 | + k := "" | |
| 28 | + if key != nil { | |
| 29 | + k = key(r) | |
| 30 | + } | |
| 31 | + if k == "" { | |
| 32 | + next.ServeHTTP(w, r) | |
| 33 | + return | |
| 34 | + } | |
| 35 | + d, err := l.Allow(r.Context(), p, k) | |
| 36 | + if err != nil { | |
| 37 | + // Fail-open on counter errors — surface in logs via | |
| 38 | + // the wrapped handler's logging middleware. We still | |
| 39 | + // stamp the headers we know. | |
| 40 | + StampHeaders(w, d) | |
| 41 | + next.ServeHTTP(w, r) | |
| 42 | + return | |
| 43 | + } | |
| 44 | + StampHeaders(w, d) | |
| 45 | + if !d.Allowed { | |
| 46 | + w.Header().Set("Retry-After", strconv.Itoa(int(d.RetryAfter/time.Second))) | |
| 47 | + // Headers MUST be set before WriteHeader for the | |
| 48 | + // response Content-Type to land; matches the same | |
| 49 | + // pattern as the writeRetryAfter fix from S05. | |
| 50 | + w.Header().Set("Content-Type", "text/plain; charset=utf-8") | |
| 51 | + w.WriteHeader(http.StatusTooManyRequests) | |
| 52 | + _, _ = w.Write([]byte("Rate limit exceeded. Please retry after the X-RateLimit-Reset window.\n")) | |
| 53 | + return | |
| 54 | + } | |
| 55 | + next.ServeHTTP(w, r) | |
| 56 | + }) | |
| 57 | + } | |
| 58 | +} | |
| 59 | + | |
| 60 | +// StampHeaders writes the standard X-RateLimit-* headers from a | |
| 61 | +// Decision. Exposed so handlers that run a manual Allow can apply | |
| 62 | +// the same stamp without going through Middleware. | |
| 63 | +func StampHeaders(w http.ResponseWriter, d Decision) { | |
| 64 | + if d.Limit > 0 { | |
| 65 | + w.Header().Set("X-RateLimit-Limit", strconv.Itoa(d.Limit)) | |
| 66 | + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(d.Remaining)) | |
| 67 | + w.Header().Set("X-RateLimit-Reset", strconv.Itoa(int(d.ResetIn/time.Second))) | |
| 68 | + } | |
| 69 | +} | |