tenseleyflow/shithub / d1e0592

Browse files

S35: ratelimit — token-bucket Allow + middleware + IP keyer

Authored by espadonne
SHA
d1e0592205f27fa98545c77daa9d941825c582da
Parents
d2d96f3
Tree
435bd61

3 changed files

StatusFile+-
A internal/ratelimit/bucket.go 159 0
A internal/ratelimit/keys.go 61 0
A internal/ratelimit/middleware.go 69 0
internal/ratelimit/bucket.goadded
@@ -0,0 +1,159 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package ratelimit owns the cross-surface counter used by S35's
4
+// rate-limit middleware. It generalizes S05's auth/throttle pattern
5
+// against the shared rate_limits table so any new surface (API,
6
+// search, git transports) plugs in with a single Allow() call.
7
+//
8
+// Backend: Postgres-backed fixed-window counter via sqlc UPSERT.
9
+// At launch scale this is well within Postgres's comfort zone and
10
+// avoids introducing a Redis dependency. Move only if profiling
11
+// demands it (S36).
12
+package ratelimit
13
+
14
+import (
15
+	"context"
16
+	"errors"
17
+	"fmt"
18
+	"net/netip"
19
+	"time"
20
+
21
+	"github.com/jackc/pgx/v5/pgtype"
22
+	"github.com/jackc/pgx/v5/pgxpool"
23
+
24
+	ratelimitdb "github.com/tenseleyFlow/shithub/internal/ratelimit/sqlc"
25
+)
26
+
27
+// Limiter is the package's primary handle. Construct with New.
28
+type Limiter struct {
29
+	q    *ratelimitdb.Queries
30
+	pool *pgxpool.Pool
31
+}
32
+
33
+// New wires a Limiter against a pool. The pool is required;
34
+// constructing with nil panics so callers fail at boot, not at
35
+// first request.
36
+func New(pool *pgxpool.Pool) *Limiter {
37
+	if pool == nil {
38
+		panic("ratelimit: nil pool")
39
+	}
40
+	return &Limiter{q: ratelimitdb.New(), pool: pool}
41
+}
42
+
43
+// Policy declares a per-(scope, key) limit: at most Max hits within
44
+// the rolling Window.
45
+type Policy struct {
46
+	Scope  string        // e.g. "api:anon", "search", "git:https"
47
+	Max    int           // hits permitted within Window
48
+	Window time.Duration // window length (15s, 1m, 1h, …)
49
+}
50
+
51
+// Decision is the verdict from Allow.
52
+type Decision struct {
53
+	Allowed    bool
54
+	Remaining  int           // hits left in the current window (post-increment)
55
+	Limit      int           // == Policy.Max, surfaced for the X-RateLimit-Limit header
56
+	ResetIn    time.Duration // wall-clock until the current window rolls over
57
+	RetryAfter time.Duration // 0 when Allowed; otherwise the wait the client should respect
58
+}
59
+
60
+// Allow increments the (scope, key) counter and reports whether the
61
+// caller is under or over the configured Max. Returns the post-
62
+// increment Remaining + the time until the current window rolls.
63
+//
64
+// On a Postgres error the request is allowed (fail-open). The caller
65
+// is expected to log the error; refusing service over a transient
66
+// counter glitch would be worse than the brief over-limit window.
67
+func (l *Limiter) Allow(ctx context.Context, p Policy, key string) (Decision, error) {
68
+	if p.Max <= 0 || p.Window <= 0 {
69
+		return Decision{}, errors.New("ratelimit: Policy.Max and Window must be positive")
70
+	}
71
+	if p.Scope == "" || key == "" {
72
+		return Decision{}, errors.New("ratelimit: Policy.Scope and key must be non-empty")
73
+	}
74
+	row, err := l.q.BumpRateLimit(ctx, l.pool, ratelimitdb.BumpRateLimitParams{
75
+		Scope: p.Scope,
76
+		Key:   key,
77
+		Ttl:   pgtype.Interval{Microseconds: int64(p.Window / time.Microsecond), Valid: true},
78
+	})
79
+	if err != nil {
80
+		return Decision{Allowed: true, Remaining: p.Max, Limit: p.Max, ResetIn: p.Window}, fmt.Errorf("ratelimit: bump: %w", err)
81
+	}
82
+
83
+	hits := int(row.Hits)
84
+	resetIn := time.Until(row.WindowStartedAt.Time.Add(p.Window))
85
+	if resetIn < 0 {
86
+		resetIn = 0
87
+	}
88
+	d := Decision{
89
+		Allowed:   hits <= p.Max,
90
+		Limit:     p.Max,
91
+		Remaining: max0(p.Max - hits),
92
+		ResetIn:   resetIn,
93
+	}
94
+	if !d.Allowed {
95
+		d.RetryAfter = resetIn
96
+		if d.RetryAfter <= 0 {
97
+			d.RetryAfter = time.Second
98
+		}
99
+	}
100
+	return d, nil
101
+}
102
+
103
+// AllowSignupIP is the inet-keyed sibling of Allow against the
104
+// signup_ip_throttle table. ip is masked to /24 (v4) or /48 (v6)
105
+// so a single residential allocation shares one counter — matches
106
+// GitHub's approach (per-/24 signup throttle).
107
+func (l *Limiter) AllowSignupIP(ctx context.Context, ip netip.Addr, max int, window time.Duration) (Decision, error) {
108
+	if !ip.IsValid() {
109
+		return Decision{}, errors.New("ratelimit: invalid ip")
110
+	}
111
+	row, err := l.q.BumpSignupIPThrottle(ctx, l.pool, ratelimitdb.BumpSignupIPThrottleParams{
112
+		Cidr: maskToNetwork(ip),
113
+		Ttl:  pgtype.Interval{Microseconds: int64(window / time.Microsecond), Valid: true},
114
+	})
115
+	if err != nil {
116
+		return Decision{Allowed: true, Remaining: max, Limit: max, ResetIn: window}, fmt.Errorf("ratelimit: signup bump: %w", err)
117
+	}
118
+	hits := int(row.Hits)
119
+	resetIn := time.Until(row.WindowStartedAt.Time.Add(window))
120
+	if resetIn < 0 {
121
+		resetIn = 0
122
+	}
123
+	d := Decision{
124
+		Allowed:   hits <= max,
125
+		Limit:     max,
126
+		Remaining: max0(max - hits),
127
+		ResetIn:   resetIn,
128
+	}
129
+	if !d.Allowed {
130
+		d.RetryAfter = resetIn
131
+		if d.RetryAfter <= 0 {
132
+			d.RetryAfter = time.Second
133
+		}
134
+	}
135
+	return d, nil
136
+}
137
+
138
+func max0(n int) int {
139
+	if n < 0 {
140
+		return 0
141
+	}
142
+	return n
143
+}
144
+
145
+// maskToNetwork zeros the host bits of ip so per-/24 (v4) or /48 (v6)
146
+// throttle keys collapse to one network row. The choice of /24 + /48
147
+// matches GitHub's reported anti-abuse defaults.
148
+func maskToNetwork(ip netip.Addr) netip.Addr {
149
+	if ip.Is4() {
150
+		b := ip.As4()
151
+		b[3] = 0
152
+		return netip.AddrFrom4(b)
153
+	}
154
+	b := ip.As16()
155
+	for i := 6; i < 16; i++ { // zero everything past /48
156
+		b[i] = 0
157
+	}
158
+	return netip.AddrFrom16(b)
159
+}
internal/ratelimit/keys.goadded
@@ -0,0 +1,61 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package ratelimit
4
+
5
+import (
6
+	"net"
7
+	"net/http"
8
+	"net/netip"
9
+	"strings"
10
+)
11
+
12
+// IPKey is the canonical anonymous-request keyer: extracts the
13
+// client IP from r.RemoteAddr or the X-Forwarded-For header. The
14
+// trust-XFF flag should be set ONLY when the deployment runs
15
+// behind a CDN/proxy we control; otherwise an attacker can spoof
16
+// the header and dodge IP-keyed limits.
17
+func IPKey(trustForwarded bool) KeyFunc {
18
+	return func(r *http.Request) string {
19
+		if trustForwarded {
20
+			if v := r.Header.Get("X-Forwarded-For"); v != "" {
21
+				// First entry is the client; downstream proxies append.
22
+				if comma := strings.IndexByte(v, ','); comma > 0 {
23
+					v = v[:comma]
24
+				}
25
+				return "ip:" + strings.TrimSpace(v)
26
+			}
27
+		}
28
+		host, _, err := net.SplitHostPort(r.RemoteAddr)
29
+		if err != nil {
30
+			host = r.RemoteAddr
31
+		}
32
+		return "ip:" + host
33
+	}
34
+}
35
+
36
+// ClientIP returns the parsed client IP using the same rules as
37
+// IPKey. Used by signup-throttle's CIDR keying.
38
+func ClientIP(r *http.Request, trustForwarded bool) (netip.Addr, bool) {
39
+	raw := ""
40
+	if trustForwarded {
41
+		if v := r.Header.Get("X-Forwarded-For"); v != "" {
42
+			if comma := strings.IndexByte(v, ','); comma > 0 {
43
+				v = v[:comma]
44
+			}
45
+			raw = strings.TrimSpace(v)
46
+		}
47
+	}
48
+	if raw == "" {
49
+		host, _, err := net.SplitHostPort(r.RemoteAddr)
50
+		if err == nil {
51
+			raw = host
52
+		} else {
53
+			raw = r.RemoteAddr
54
+		}
55
+	}
56
+	addr, err := netip.ParseAddr(raw)
57
+	if err != nil {
58
+		return netip.Addr{}, false
59
+	}
60
+	return addr, true
61
+}
internal/ratelimit/middleware.goadded
@@ -0,0 +1,69 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package ratelimit
4
+
5
+import (
6
+	"net/http"
7
+	"strconv"
8
+	"time"
9
+)
10
+
11
+// KeyFunc derives the per-request rate-limit key. Common choices:
12
+// remote IP (anon), user id (authed), token id (PAT), repo id+user
13
+// (per-repo), etc. Returning "" skips the limiter (caller should
14
+// only return "" when there's nothing meaningful to throttle on).
15
+type KeyFunc func(*http.Request) string
16
+
17
+// Middleware wraps next with a per-request Allow check. On allow,
18
+// X-RateLimit-* headers are stamped on the response; on deny, the
19
+// response is 429 with Retry-After.
20
+//
21
+// `name` is the per-route identifier surfaced in logs and lets the
22
+// admin observability pages disambiguate scopes that share the same
23
+// counter (e.g. "api:anon" used by both /api/repos and /api/users).
24
+func (l *Limiter) Middleware(p Policy, key KeyFunc) func(http.Handler) http.Handler {
25
+	return func(next http.Handler) http.Handler {
26
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27
+			k := ""
28
+			if key != nil {
29
+				k = key(r)
30
+			}
31
+			if k == "" {
32
+				next.ServeHTTP(w, r)
33
+				return
34
+			}
35
+			d, err := l.Allow(r.Context(), p, k)
36
+			if err != nil {
37
+				// Fail-open on counter errors — surface in logs via
38
+				// the wrapped handler's logging middleware. We still
39
+				// stamp the headers we know.
40
+				StampHeaders(w, d)
41
+				next.ServeHTTP(w, r)
42
+				return
43
+			}
44
+			StampHeaders(w, d)
45
+			if !d.Allowed {
46
+				w.Header().Set("Retry-After", strconv.Itoa(int(d.RetryAfter/time.Second)))
47
+				// Headers MUST be set before WriteHeader for the
48
+				// response Content-Type to land; matches the same
49
+				// pattern as the writeRetryAfter fix from S05.
50
+				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
51
+				w.WriteHeader(http.StatusTooManyRequests)
52
+				_, _ = w.Write([]byte("Rate limit exceeded. Please retry after the X-RateLimit-Reset window.\n"))
53
+				return
54
+			}
55
+			next.ServeHTTP(w, r)
56
+		})
57
+	}
58
+}
59
+
60
+// StampHeaders writes the standard X-RateLimit-* headers from a
61
+// Decision. Exposed so handlers that run a manual Allow can apply
62
+// the same stamp without going through Middleware.
63
+func StampHeaders(w http.ResponseWriter, d Decision) {
64
+	if d.Limit > 0 {
65
+		w.Header().Set("X-RateLimit-Limit", strconv.Itoa(d.Limit))
66
+		w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(d.Remaining))
67
+		w.Header().Set("X-RateLimit-Reset", strconv.Itoa(int(d.ResetIn/time.Second)))
68
+	}
69
+}