Go · 8021 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package middleware
4
5 import (
6 "context"
7 "encoding/base64"
8 "errors"
9 "log/slog"
10 "net/http"
11 "net/netip"
12 "strings"
13 "time"
14
15 "github.com/jackc/pgx/v5/pgxpool"
16
17 "github.com/tenseleyFlow/shithub/internal/auth/pat"
18 "github.com/tenseleyFlow/shithub/internal/auth/policy"
19 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
20 )
21
22 var patAuthKey = ctxKey{name: "pat_auth"}
23
24 // PATAuth carries the resolved-token state for downstream handlers. When
25 // the auth check passed via PAT, `Token != nil` and Scopes is the parsed
26 // scope list. Pure session callers see the zero value.
27 type PATAuth struct {
28 UserID int64
29 Username string
30 TokenID int64
31 Scopes []string
32 IsSuspended bool
33 IsSiteAdmin bool
34 }
35
36 // PATAuthFromContext returns the resolved PAT auth state, or the zero
37 // value when the request was authenticated some other way (or anonymous).
38 func PATAuthFromContext(ctx context.Context) PATAuth {
39 if v, ok := ctx.Value(patAuthKey).(PATAuth); ok {
40 return v
41 }
42 return PATAuth{}
43 }
44
45 // PolicyActor returns the canonical policy actor for a resolved PAT request.
46 func (p PATAuth) PolicyActor() policy.Actor {
47 if p.UserID == 0 {
48 return policy.AnonymousActor()
49 }
50 return policy.UserActor(p.UserID, p.Username, p.IsSuspended, p.IsSiteAdmin)
51 }
52
53 // PATConfig configures the PAT auth middleware.
54 type PATConfig struct {
55 Pool *pgxpool.Pool
56 Debouncer *pat.Debouncer // optional; one is created if nil
57 Logger *slog.Logger
58 // Realm is the WWW-Authenticate realm string written on 401.
59 Realm string
60 }
61
62 // PATAuthMiddleware returns middleware that resolves a
63 // `Authorization: token ...`, `Authorization: Bearer ...`, or HTTP Basic
64 // credential into a populated PATAuth on the request context.
65 //
66 // Behavior:
67 //
68 // - Missing Authorization header → next handler runs with an empty
69 // PATAuth (the route may still allow session auth).
70 // - Malformed credentials, unknown hash, revoked, or expired token →
71 // 401 with WWW-Authenticate set; chain stops here.
72 // - On success → PATAuth populated on context, last-used updated
73 // (debounced), chain proceeds.
74 func PATAuthMiddleware(cfg PATConfig) func(http.Handler) http.Handler {
75 if cfg.Debouncer == nil {
76 cfg.Debouncer = pat.NewDebouncer(60 * time.Second)
77 }
78 if cfg.Realm == "" {
79 cfg.Realm = "shithub"
80 }
81 q := usersdb.New()
82 return func(next http.Handler) http.Handler {
83 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
84 raw, err := extractPAT(r)
85 if errors.Is(err, errNoCredentials) {
86 next.ServeHTTP(w, r)
87 return
88 }
89 if err != nil {
90 writePATChallenge(w, cfg.Realm, "invalid token")
91 return
92 }
93
94 hash, err := pat.HashOf(raw)
95 if err != nil {
96 writePATChallenge(w, cfg.Realm, "invalid token")
97 return
98 }
99
100 row, err := q.GetUserTokenByHash(r.Context(), cfg.Pool, hash)
101 if err != nil {
102 // pgx.ErrNoRows or any other DB error: respond identically
103 // so we don't leak hash existence via timing/messaging.
104 writePATChallenge(w, cfg.Realm, "invalid token")
105 return
106 }
107 if row.RevokedAt.Valid {
108 writePATChallenge(w, cfg.Realm, "token revoked")
109 return
110 }
111 if row.ExpiresAt.Valid && time.Now().After(row.ExpiresAt.Time) {
112 writePATChallenge(w, cfg.Realm, "token expired")
113 return
114 }
115
116 // Verify owner is not suspended / deleted. The cascade DELETE
117 // on user removal handles deletion; suspension we check explicitly.
118 user, err := q.GetUserByID(r.Context(), cfg.Pool, row.UserID)
119 if err != nil {
120 writePATChallenge(w, cfg.Realm, "invalid token")
121 return
122 }
123 if user.SuspendedAt.Valid {
124 writePATChallenge(w, cfg.Realm, "account suspended")
125 return
126 }
127
128 // Debounced last-used update — never blocks the request.
129 // G118: we INTENTIONALLY detach from r.Context() so the
130 // update survives client disconnect (a debounced touch is
131 // already a best-effort write; we'd rather complete it than
132 // drop on cancel).
133 if cfg.Debouncer.ShouldTouch(row.ID) {
134 ip := remoteAddrFromRequest(r)
135 rowID := row.ID
136 go func() { //nolint:gosec
137 ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
138 defer cancel()
139 if err := q.TouchUserTokenLastUsed(ctx, cfg.Pool, usersdb.TouchUserTokenLastUsedParams{
140 ID: rowID,
141 LastUsedIp: ip,
142 }); err != nil && cfg.Logger != nil {
143 cfg.Logger.WarnContext(ctx, "pat: touch last_used", "error", err)
144 }
145 }()
146 }
147
148 ctx := context.WithValue(r.Context(), patAuthKey, PATAuth{
149 UserID: row.UserID,
150 Username: user.Username,
151 TokenID: row.ID,
152 Scopes: row.Scopes,
153 IsSuspended: user.SuspendedAt.Valid,
154 IsSiteAdmin: user.IsSiteAdmin,
155 })
156 next.ServeHTTP(w, r.WithContext(ctx))
157 })
158 }
159 }
160
161 // RequireScope rejects with 403 if the request was authenticated via PAT
162 // and the token's scopes don't include required. Pure-session callers
163 // (PATAuth zero) pass through — sessions have implicit full scope.
164 func RequireScope(required pat.Scope) func(http.Handler) http.Handler {
165 return func(next http.Handler) http.Handler {
166 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
167 auth := PATAuthFromContext(r.Context())
168 if auth.TokenID == 0 {
169 next.ServeHTTP(w, r)
170 return
171 }
172 if !pat.HasScope(auth.Scopes, required) {
173 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
174 w.WriteHeader(http.StatusForbidden)
175 _, _ = w.Write([]byte("token lacks required scope: " + string(required) + "\n"))
176 return
177 }
178 next.ServeHTTP(w, r)
179 })
180 }
181 }
182
// errNoCredentials is the sentinel that says "no Authorization header at
// all" — distinct from "Authorization present but malformed." Callers let
// the former fall through to other auth mechanisms and reject the latter.
var errNoCredentials = errors.New("middleware: no credentials")

// extractPAT returns the raw personal access token carried by r.
//
// Accepted forms (scheme names are matched case-insensitively):
//
//   - Authorization: token <pat>
//   - Authorization: Bearer <pat>
//   - Authorization: Basic base64(<user>:<pat>) — the password is the PAT
//     and the username is ignored (matches git's credential-helper output,
//     used by git-over-HTTPS in S12).
//
// It returns errNoCredentials when the header is absent. It returns a
// different non-nil error when the header is malformed, uses another
// scheme, or carries an empty credential.
func extractPAT(r *http.Request) (string, error) {
	header := r.Header.Get("Authorization")
	if header == "" {
		return "", errNoCredentials
	}
	scheme, rest, ok := strings.Cut(header, " ")
	if !ok {
		return "", errors.New("malformed Authorization header")
	}
	rest = strings.TrimSpace(rest)

	var token string
	switch strings.ToLower(scheme) {
	case "token", "bearer":
		token = rest
	case "basic":
		decoded, err := base64.StdEncoding.DecodeString(rest)
		if err != nil {
			return "", err
		}
		_, pass, ok := strings.Cut(string(decoded), ":")
		if !ok {
			return "", errors.New("malformed Basic credentials")
		}
		token = pass
	default:
		return "", errors.New("unsupported auth scheme")
	}

	// A present-but-empty credential (e.g. "Basic Og==", which decodes to
	// ":") is malformed, not anonymous: reject it here instead of handing
	// an empty string to the hashing/lookup path.
	if token == "" {
		return "", errors.New("empty credentials")
	}
	return token, nil
}
217
// writePATChallenge writes the canonical 401 for a rejected PAT: a Bearer
// challenge (RFC 6750, error="invalid_token") in WWW-Authenticate, and
// reason as a plain-text body.
//
// realm and reason are emitted as HTTP quoted-strings. Any `"` or `\` they
// contain is backslash-escaped so it cannot terminate the parameter early
// and corrupt the header. realm comes from configuration, so it is not
// assumed to be clean.
func writePATChallenge(w http.ResponseWriter, realm, reason string) {
	challenge := "Bearer realm=" + quoteAuthParam(realm) +
		`, error="invalid_token", error_description=` + quoteAuthParam(reason)
	w.Header().Set("WWW-Authenticate", challenge)
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(http.StatusUnauthorized)
	_, _ = w.Write([]byte(reason + "\n"))
}

// authParamEscaper backslash-escapes the two characters that are special
// inside an HTTP quoted-string.
var authParamEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteAuthParam renders s as an HTTP quoted-string (RFC 9110 §5.6.4).
func quoteAuthParam(s string) string {
	return `"` + authParamEscaper.Replace(s) + `"`
}
225
226 // remoteAddrFromRequest pulls the client IP for last_used_ip. Reuses the
227 // request's RealIP if set; otherwise falls back to RemoteAddr's host part.
228 func remoteAddrFromRequest(r *http.Request) *netip.Addr {
229 candidates := []string{
230 RealIPFromContext(r.Context(), r),
231 r.RemoteAddr,
232 }
233 for _, c := range candidates {
234 if c == "" {
235 continue
236 }
237 // Strip ":port" suffix if present — RemoteAddr always has one.
238 host := c
239 if i := strings.LastIndex(host, ":"); i > 0 && !strings.Contains(host[:i], "]") {
240 // IPv6 hosts come bracketed [::1]:1234 — those keep the bracket
241 // and we'd want netip.ParseAddrPort instead.
242 host = host[:i]
243 }
244 host = strings.TrimPrefix(host, "[")
245 host = strings.TrimSuffix(host, "]")
246 if addr, err := netip.ParseAddr(host); err == nil {
247 return &addr
248 }
249 }
250 return nil
251 }
252