tenseleyflow/shithub / 1d477a7

Browse files

api/apilimit: rate-limit + header-stamp middleware for /api/v1/*

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
1d477a7ea3014843fe6de6a06bbb572ea7125002
Parents
d3a1df9
Tree
4c66551

1 changed file

StatusFile+-
A internal/web/handlers/api/apilimit/apilimit.go 105 0
internal/web/handlers/api/apilimit/apilimit.goadded
@@ -0,0 +1,105 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package apilimit is the per-request rate limiter that fronts /api/v1/.
4
+// It does two things on every request:
5
+//
6
+//  1. Buckets the caller: PAT-authenticated requests are keyed by token
7
+//     id (scope "api:authed"); anonymous requests are keyed by remote
8
+//     IP (scope "api:anon"). Budgets come from cfg.RateLimit.API.
9
+//  2. Stamps X-RateLimit-Limit / X-RateLimit-Remaining / X-RateLimit-Reset
10
+//     on the response — even on success. The shithub-cli HTTP client
11
+//     parses these headers on every response to surface back-off hints
12
+//     (shithub-cli/internal/api/errors.go).
13
+//
14
+// On deny, the response is the canonical /api/v1 JSON error envelope
15
+// `{"error": "rate limit exceeded"}` with Retry-After set. Postgres
16
+// errors fail open (ratelimit.Allow's documented behavior); the request
17
+// proceeds with whatever decision we have and a warn-level log line.
18
+package apilimit
19
+
20
+import (
21
+	"log/slog"
22
+	"net/http"
23
+	"strconv"
24
+	"time"
25
+
26
+	"github.com/tenseleyFlow/shithub/internal/ratelimit"
27
+	"github.com/tenseleyFlow/shithub/internal/web/middleware"
28
+)
29
+
30
+// Config is the per-instance configuration for the middleware. Both
31
+// budgets are required to be positive at construction.
32
+type Config struct {
33
+	// AuthedPerHour is the bucket size for PAT-authenticated callers.
34
+	AuthedPerHour int
35
+	// AnonPerHour is the bucket size for unauthenticated callers.
36
+	AnonPerHour int
37
+	// Logger receives warn-level lines when the backing counter errors.
38
+	// nil disables logging.
39
+	Logger *slog.Logger
40
+}
41
+
42
+// Middleware returns a chi-compatible middleware that applies the
43
+// configured budgets and stamps the standard X-RateLimit-* headers.
44
+// When l is nil the middleware is a no-op (used by tests that don't
45
+// stand up the ratelimit DB).
46
+func Middleware(l *ratelimit.Limiter, cfg Config) func(http.Handler) http.Handler {
47
+	authedPolicy := ratelimit.Policy{
48
+		Scope:  "api:authed",
49
+		Max:    cfg.AuthedPerHour,
50
+		Window: time.Hour,
51
+	}
52
+	anonPolicy := ratelimit.Policy{
53
+		Scope:  "api:anon",
54
+		Max:    cfg.AnonPerHour,
55
+		Window: time.Hour,
56
+	}
57
+	logger := cfg.Logger
58
+	return func(next http.Handler) http.Handler {
59
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60
+			if l == nil {
61
+				next.ServeHTTP(w, r)
62
+				return
63
+			}
64
+			policy, key := pickBucket(r, authedPolicy, anonPolicy)
65
+			if policy.Max <= 0 || key == "" {
66
+				// Misconfigured budget or no key derivable — fail open
67
+				// rather than refuse service. The boot-time validation
68
+				// in config.Validate keeps this branch unreachable in
69
+				// practice.
70
+				next.ServeHTTP(w, r)
71
+				return
72
+			}
73
+			decision, err := l.Allow(r.Context(), policy, key)
74
+			if err != nil && logger != nil {
75
+				logger.WarnContext(r.Context(), "apilimit: counter error", "scope", policy.Scope, "key", key, "error", err)
76
+			}
77
+			ratelimit.StampHeaders(w, decision)
78
+			if !decision.Allowed {
79
+				retry := int(decision.RetryAfter / time.Second)
80
+				if retry < 1 {
81
+					retry = 1
82
+				}
83
+				w.Header().Set("Retry-After", strconv.Itoa(retry))
84
+				w.Header().Set("Content-Type", "application/json; charset=utf-8")
85
+				w.Header().Set("Cache-Control", "no-store")
86
+				w.WriteHeader(http.StatusTooManyRequests)
87
+				_, _ = w.Write([]byte(`{"error":"rate limit exceeded"}` + "\n"))
88
+				return
89
+			}
90
+			next.ServeHTTP(w, r)
91
+		})
92
+	}
93
+}
94
+
95
+func pickBucket(r *http.Request, authed, anon ratelimit.Policy) (ratelimit.Policy, string) {
96
+	auth := middleware.PATAuthFromContext(r.Context())
97
+	if auth.TokenID != 0 {
98
+		return authed, "pat:" + strconv.FormatInt(auth.TokenID, 10)
99
+	}
100
+	ip := middleware.RealIPFromContext(r.Context(), r)
101
+	if ip == "" {
102
+		return anon, ""
103
+	}
104
+	return anon, "ip:" + ip
105
+}