| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package middleware |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "encoding/base64" |
| 8 | "errors" |
| 9 | "log/slog" |
| 10 | "net/http" |
| 11 | "net/netip" |
| 12 | "strings" |
| 13 | "time" |
| 14 | |
| 15 | "github.com/jackc/pgx/v5/pgxpool" |
| 16 | |
| 17 | "github.com/tenseleyFlow/shithub/internal/auth/pat" |
| 18 | "github.com/tenseleyFlow/shithub/internal/auth/policy" |
| 19 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 20 | ) |
| 21 | |
| 22 | var patAuthKey = ctxKey{name: "pat_auth"} |
| 23 | |
// PATAuth carries the resolved-token state for downstream handlers. When
// the request authenticated via a personal access token, TokenID is
// non-zero and Scopes holds the token's scope list. Pure session (or
// anonymous) callers see the zero value.
type PATAuth struct {
	// UserID is the token owner's user ID.
	UserID int64
	// Username is the token owner's username.
	Username string
	// TokenID is the token row's ID; 0 means "not authenticated by PAT".
	TokenID int64
	// Scopes is the scope list stored on the token.
	Scopes []string
	// IsSuspended mirrors the owner's suspension flag.
	IsSuspended bool
	// IsSiteAdmin mirrors the owner's site-admin flag.
	IsSiteAdmin bool
}
| 35 | |
| 36 | // PATAuthFromContext returns the resolved PAT auth state, or the zero |
| 37 | // value when the request was authenticated some other way (or anonymous). |
| 38 | func PATAuthFromContext(ctx context.Context) PATAuth { |
| 39 | if v, ok := ctx.Value(patAuthKey).(PATAuth); ok { |
| 40 | return v |
| 41 | } |
| 42 | return PATAuth{} |
| 43 | } |
| 44 | |
| 45 | // PolicyActor returns the canonical policy actor for a resolved PAT request. |
| 46 | func (p PATAuth) PolicyActor() policy.Actor { |
| 47 | if p.UserID == 0 { |
| 48 | return policy.AnonymousActor() |
| 49 | } |
| 50 | return policy.UserActor(p.UserID, p.Username, p.IsSuspended, p.IsSiteAdmin) |
| 51 | } |
| 52 | |
| 53 | // PATConfig configures the PAT auth middleware. |
| 54 | type PATConfig struct { |
| 55 | Pool *pgxpool.Pool |
| 56 | Debouncer *pat.Debouncer // optional; one is created if nil |
| 57 | Logger *slog.Logger |
| 58 | // Realm is the WWW-Authenticate realm string written on 401. |
| 59 | Realm string |
| 60 | } |
| 61 | |
| 62 | // PATAuthMiddleware returns middleware that resolves a |
| 63 | // `Authorization: token ...`, `Authorization: Bearer ...`, or HTTP Basic |
| 64 | // credential into a populated PATAuth on the request context. |
| 65 | // |
| 66 | // Behavior: |
| 67 | // |
| 68 | // - Missing Authorization header → next handler runs with an empty |
| 69 | // PATAuth (the route may still allow session auth). |
| 70 | // - Malformed credentials, unknown hash, revoked, or expired token → |
| 71 | // 401 with WWW-Authenticate set; chain stops here. |
| 72 | // - On success → PATAuth populated on context, last-used updated |
| 73 | // (debounced), chain proceeds. |
| 74 | func PATAuthMiddleware(cfg PATConfig) func(http.Handler) http.Handler { |
| 75 | if cfg.Debouncer == nil { |
| 76 | cfg.Debouncer = pat.NewDebouncer(60 * time.Second) |
| 77 | } |
| 78 | if cfg.Realm == "" { |
| 79 | cfg.Realm = "shithub" |
| 80 | } |
| 81 | q := usersdb.New() |
| 82 | return func(next http.Handler) http.Handler { |
| 83 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 84 | raw, err := extractPAT(r) |
| 85 | if errors.Is(err, errNoCredentials) { |
| 86 | next.ServeHTTP(w, r) |
| 87 | return |
| 88 | } |
| 89 | if err != nil { |
| 90 | writePATChallenge(w, cfg.Realm, "invalid token") |
| 91 | return |
| 92 | } |
| 93 | |
| 94 | hash, err := pat.HashOf(raw) |
| 95 | if err != nil { |
| 96 | writePATChallenge(w, cfg.Realm, "invalid token") |
| 97 | return |
| 98 | } |
| 99 | |
| 100 | row, err := q.GetUserTokenByHash(r.Context(), cfg.Pool, hash) |
| 101 | if err != nil { |
| 102 | // pgx.ErrNoRows or any other DB error: respond identically |
| 103 | // so we don't leak hash existence via timing/messaging. |
| 104 | writePATChallenge(w, cfg.Realm, "invalid token") |
| 105 | return |
| 106 | } |
| 107 | if row.RevokedAt.Valid { |
| 108 | writePATChallenge(w, cfg.Realm, "token revoked") |
| 109 | return |
| 110 | } |
| 111 | if row.ExpiresAt.Valid && time.Now().After(row.ExpiresAt.Time) { |
| 112 | writePATChallenge(w, cfg.Realm, "token expired") |
| 113 | return |
| 114 | } |
| 115 | |
| 116 | // Verify owner is not suspended / deleted. The cascade DELETE |
| 117 | // on user removal handles deletion; suspension we check explicitly. |
| 118 | user, err := q.GetUserByID(r.Context(), cfg.Pool, row.UserID) |
| 119 | if err != nil { |
| 120 | writePATChallenge(w, cfg.Realm, "invalid token") |
| 121 | return |
| 122 | } |
| 123 | if user.SuspendedAt.Valid { |
| 124 | writePATChallenge(w, cfg.Realm, "account suspended") |
| 125 | return |
| 126 | } |
| 127 | |
| 128 | // Debounced last-used update — never blocks the request. |
| 129 | // G118: we INTENTIONALLY detach from r.Context() so the |
| 130 | // update survives client disconnect (a debounced touch is |
| 131 | // already a best-effort write; we'd rather complete it than |
| 132 | // drop on cancel). |
| 133 | if cfg.Debouncer.ShouldTouch(row.ID) { |
| 134 | ip := remoteAddrFromRequest(r) |
| 135 | rowID := row.ID |
| 136 | go func() { //nolint:gosec |
| 137 | ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) |
| 138 | defer cancel() |
| 139 | if err := q.TouchUserTokenLastUsed(ctx, cfg.Pool, usersdb.TouchUserTokenLastUsedParams{ |
| 140 | ID: rowID, |
| 141 | LastUsedIp: ip, |
| 142 | }); err != nil && cfg.Logger != nil { |
| 143 | cfg.Logger.WarnContext(ctx, "pat: touch last_used", "error", err) |
| 144 | } |
| 145 | }() |
| 146 | } |
| 147 | |
| 148 | ctx := context.WithValue(r.Context(), patAuthKey, PATAuth{ |
| 149 | UserID: row.UserID, |
| 150 | Username: user.Username, |
| 151 | TokenID: row.ID, |
| 152 | Scopes: row.Scopes, |
| 153 | IsSuspended: user.SuspendedAt.Valid, |
| 154 | IsSiteAdmin: user.IsSiteAdmin, |
| 155 | }) |
| 156 | next.ServeHTTP(w, r.WithContext(ctx)) |
| 157 | }) |
| 158 | } |
| 159 | } |
| 160 | |
| 161 | // RequireScope rejects with 403 if the request was authenticated via PAT |
| 162 | // and the token's scopes don't include required. Pure-session callers |
| 163 | // (PATAuth zero) pass through — sessions have implicit full scope. |
| 164 | func RequireScope(required pat.Scope) func(http.Handler) http.Handler { |
| 165 | return func(next http.Handler) http.Handler { |
| 166 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 167 | auth := PATAuthFromContext(r.Context()) |
| 168 | if auth.TokenID == 0 { |
| 169 | next.ServeHTTP(w, r) |
| 170 | return |
| 171 | } |
| 172 | if !pat.HasScope(auth.Scopes, required) { |
| 173 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") |
| 174 | w.WriteHeader(http.StatusForbidden) |
| 175 | _, _ = w.Write([]byte("token lacks required scope: " + string(required) + "\n")) |
| 176 | return |
| 177 | } |
| 178 | next.ServeHTTP(w, r) |
| 179 | }) |
| 180 | } |
| 181 | } |
| 182 | |
// errNoCredentials is the sentinel meaning "no Authorization header at
// all". It is distinct from "Authorization present but malformed", which
// callers must reject rather than fall through.
var errNoCredentials = errors.New("middleware: no credentials")

// extractPAT parses the inbound credential into the raw token string.
// Supported forms:
//
//   - `Authorization: token <pat>`
//   - `Authorization: Bearer <pat>`
//   - HTTP Basic where the password is the PAT (matches `git`'s
//     credential-helper output, used by git-over-HTTPS in S12).
//
// It returns errNoCredentials when the header is absent. Any other non-nil
// error means the header was present but unusable. A nil error always comes
// with a non-empty token.
func extractPAT(r *http.Request) (string, error) {
	header := r.Header.Get("Authorization")
	if header == "" {
		return "", errNoCredentials
	}
	scheme, rest, ok := strings.Cut(header, " ")
	if !ok {
		return "", errors.New("malformed Authorization header")
	}

	var token string
	switch strings.ToLower(scheme) {
	case "token", "bearer":
		token = strings.TrimSpace(rest)
	case "basic":
		decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(rest))
		if err != nil {
			return "", err
		}
		_, pass, ok := strings.Cut(string(decoded), ":")
		if !ok {
			return "", errors.New("malformed Basic credentials")
		}
		token = pass
	default:
		return "", errors.New("unsupported auth scheme")
	}

	// e.g. Basic "user:" with no password: reject here instead of handing an
	// empty string to the hasher and the token lookup.
	if token == "" {
		return "", errors.New("empty token")
	}
	return token, nil
}
| 217 | |
// writePATChallenge writes the canonical 401 response: a Bearer
// WWW-Authenticate challenge (error="invalid_token", with reason as the
// error_description), a text/plain content type, and reason as the body.
func writePATChallenge(w http.ResponseWriter, realm, reason string) {
	// realm and reason are embedded in quoted-strings; escape backslash and
	// double quote as quoted-pairs so an operator-supplied realm containing
	// either cannot terminate the string early and corrupt the header.
	quote := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace
	w.Header().Set("WWW-Authenticate",
		`Bearer realm="`+quote(realm)+`", error="invalid_token", error_description="`+quote(reason)+`"`)
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(http.StatusUnauthorized)
	_, _ = w.Write([]byte(reason + "\n"))
}
| 225 | |
| 226 | // remoteAddrFromRequest pulls the client IP for last_used_ip. Reuses the |
| 227 | // request's RealIP if set; otherwise falls back to RemoteAddr's host part. |
| 228 | func remoteAddrFromRequest(r *http.Request) *netip.Addr { |
| 229 | candidates := []string{ |
| 230 | RealIPFromContext(r.Context(), r), |
| 231 | r.RemoteAddr, |
| 232 | } |
| 233 | for _, c := range candidates { |
| 234 | if c == "" { |
| 235 | continue |
| 236 | } |
| 237 | // Strip ":port" suffix if present — RemoteAddr always has one. |
| 238 | host := c |
| 239 | if i := strings.LastIndex(host, ":"); i > 0 && !strings.Contains(host[:i], "]") { |
| 240 | // IPv6 hosts come bracketed [::1]:1234 — those keep the bracket |
| 241 | // and we'd want netip.ParseAddrPort instead. |
| 242 | host = host[:i] |
| 243 | } |
| 244 | host = strings.TrimPrefix(host, "[") |
| 245 | host = strings.TrimSuffix(host, "]") |
| 246 | if addr, err := netip.ParseAddr(host); err == nil { |
| 247 | return &addr |
| 248 | } |
| 249 | } |
| 250 | return nil |
| 251 | } |
| 252 |