Go · 17517 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package auth
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "net/http"
10 "strings"
11 "time"
12
13 "github.com/jackc/pgx/v5"
14 "github.com/jackc/pgx/v5/pgtype"
15
16 "github.com/tenseleyFlow/shithub/internal/auth/audit"
17 "github.com/tenseleyFlow/shithub/internal/auth/email"
18 "github.com/tenseleyFlow/shithub/internal/auth/password"
19 "github.com/tenseleyFlow/shithub/internal/auth/throttle"
20 "github.com/tenseleyFlow/shithub/internal/auth/totp"
21 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
22 "github.com/tenseleyFlow/shithub/internal/web/middleware"
23 )
24
25 // ============================ login challenge ===========================
26
27 func (h *Handlers) twoFactorChallengeForm(w http.ResponseWriter, r *http.Request) {
28 s := middleware.SessionFromContext(r.Context())
29 if s.Pre2FAUserID == 0 {
30 http.Redirect(w, r, "/login", http.StatusSeeOther)
31 return
32 }
33 h.renderPage(w, r, "auth/2fa_challenge", map[string]any{
34 "Title": "Two-factor authentication",
35 "CSRFToken": middleware.CSRFTokenForRequest(r),
36 "Next": r.URL.Query().Get("next"),
37 })
38 }
39
40 func (h *Handlers) twoFactorChallengeSubmit(w http.ResponseWriter, r *http.Request) {
41 if err := r.ParseForm(); err != nil {
42 h.d.Render.HTTPError(w, r, http.StatusBadRequest, "form parse")
43 return
44 }
45 s := middleware.SessionFromContext(r.Context())
46 if s.Pre2FAUserID == 0 {
47 http.Redirect(w, r, "/login", http.StatusSeeOther)
48 return
49 }
50 userID := s.Pre2FAUserID
51 code := strings.TrimSpace(r.PostFormValue("code"))
52 next := r.PostFormValue("next")
53
54 render := func(msg string) {
55 h.renderPage(w, r, "auth/2fa_challenge", map[string]any{
56 "Title": "Two-factor authentication",
57 "CSRFToken": middleware.CSRFTokenForRequest(r),
58 "Error": msg,
59 "Next": next,
60 })
61 }
62
63 throttleKey := fmt.Sprintf("ip:%s|uid:%d", clientIP(r), userID)
64 if err := h.d.Limiter.Hit(r.Context(), h.d.Pool, throttle.Limit{
65 Scope: "2fa", Identifier: throttleKey,
66 Max: 5, Window: 5 * time.Minute,
67 }); err != nil {
68 h.writeRetryAfter(w, err)
69 render("Too many failed attempts. Please sign in again.")
70 // Drop pre-2fa marker so caller restarts the flow.
71 s.Pre2FAUserID = 0
72 _ = h.d.SessionStore.Save(w, r, s)
73 return
74 }
75
76 if code == "" {
77 render("Enter your 6-digit code or a recovery code.")
78 return
79 }
80
81 accepted := false
82 usedRecovery := false
83 if totp.LooksLikeRecoveryCode(code) {
84 ok, err := h.consumeRecoveryCode(r.Context(), userID, code)
85 if err != nil {
86 h.d.Logger.ErrorContext(r.Context(), "2fa: consume recovery", "error", err)
87 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
88 return
89 }
90 accepted = ok
91 usedRecovery = ok
92 } else {
93 ok, err := h.verifyTOTPCode(r.Context(), userID, code)
94 if err != nil {
95 h.d.Logger.ErrorContext(r.Context(), "2fa: verify totp", "error", err)
96 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
97 return
98 }
99 accepted = ok
100 }
101
102 if !accepted {
103 render("Incorrect code. Try again.")
104 return
105 }
106
107 // Forgive prior failed-attempt counter on success.
108 _ = h.d.Limiter.Reset(r.Context(), h.d.Pool, "2fa", throttleKey)
109
110 if err := h.q.TouchUserLastLogin(r.Context(), h.d.Pool, userID); err != nil {
111 h.d.Logger.WarnContext(r.Context(), "2fa: touch last_login_at", "error", err)
112 }
113
114 if usedRecovery {
115 _ = h.d.Audit.Record(r.Context(), h.d.Pool, userID,
116 audit.ActionRecoveryCodeUsed, audit.TargetUser, userID, nil)
117 }
118
119 // Upgrade session: drop pre-2FA marker, set UserID, reissue.
120 // Recent2FAAt timestamps the just-completed challenge so the recent-
121 // auth gate (PAT creation, etc.) can verify a fresh second factor.
122 // Epoch snapshot powers "log out everywhere".
123 epoch, err := h.q.GetUserSessionEpoch(r.Context(), h.d.Pool, userID)
124 if err != nil {
125 h.d.Logger.ErrorContext(r.Context(), "2fa: load epoch", "error", err)
126 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
127 return
128 }
129 s.Pre2FAUserID = 0
130 s.UserID = userID
131 s.Epoch = epoch
132 s.Recent2FAAt = time.Now().Unix()
133 s.IssuedAt = time.Now().Unix()
134 if err := h.d.SessionStore.Save(w, r, s); err != nil {
135 h.d.Logger.ErrorContext(r.Context(), "2fa: save session", "error", err)
136 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
137 return
138 }
139
140 dest := "/"
141 if next != "" && strings.HasPrefix(next, "/") && !strings.HasPrefix(next, "//") {
142 dest = next
143 }
144 //nolint:gosec // G710: dest is whitelisted to single-leading-slash relative paths.
145 http.Redirect(w, r, dest, http.StatusSeeOther)
146 }
147
148 // ============================ enrollment ================================
149
150 func (h *Handlers) twoFactorEnableForm(w http.ResponseWriter, r *http.Request) {
151 user := middleware.CurrentUserFromContext(r.Context())
152
153 // If already enrolled and confirmed, send to disable page instead.
154 if existing, err := h.q.GetUserTOTP(r.Context(), h.d.Pool, user.ID); err == nil && existing.ConfirmedAt.Valid {
155 http.Redirect(w, r, "/settings/security/2fa/disable", http.StatusSeeOther)
156 return
157 }
158
159 // Mint or replace a pending secret. UpsertUserTOTP only updates when
160 // confirmed_at IS NULL — confirmed rows are protected.
161 secret, err := totp.GenerateSecret()
162 if err != nil {
163 h.d.Logger.ErrorContext(r.Context(), "2fa: secret", "error", err)
164 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
165 return
166 }
167 enc, nonce, err := h.d.SecretBox.Seal(secret)
168 if err != nil {
169 h.d.Logger.ErrorContext(r.Context(), "2fa: seal", "error", err)
170 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
171 return
172 }
173 if _, err := h.q.UpsertUserTOTP(r.Context(), h.d.Pool, usersdb.UpsertUserTOTPParams{
174 UserID: user.ID,
175 SecretEncrypted: enc,
176 SecretNonce: nonce,
177 }); err != nil {
178 h.d.Logger.ErrorContext(r.Context(), "2fa: upsert secret", "error", err)
179 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
180 return
181 }
182
183 uri := totp.OtpauthURI(h.d.Branding.SiteName, user.Username, secret)
184 svg, err := totp.QRSVG(uri)
185 if err != nil {
186 h.d.Logger.ErrorContext(r.Context(), "2fa: qr", "error", err)
187 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
188 return
189 }
190
191 h.renderPage(w, r, "settings/2fa_enable", map[string]any{
192 "Title": "Enable two-factor authentication",
193 "CSRFToken": middleware.CSRFTokenForRequest(r),
194 "SettingsActive": "2fa",
195 "QRSvg": svg,
196 "Secret": totp.EncodeBase32(secret), // displayed for manual entry; also high-entropy + redacted in logs
197 })
198 }
199
200 func (h *Handlers) twoFactorEnableSubmit(w http.ResponseWriter, r *http.Request) {
201 if err := r.ParseForm(); err != nil {
202 h.d.Render.HTTPError(w, r, http.StatusBadRequest, "form parse")
203 return
204 }
205 user := middleware.CurrentUserFromContext(r.Context())
206 code := strings.TrimSpace(r.PostFormValue("code"))
207
208 render := func(msg string, recoveryCodes []string) {
209 data := map[string]any{
210 "Title": "Enable two-factor authentication",
211 "CSRFToken": middleware.CSRFTokenForRequest(r),
212 "SettingsActive": "2fa",
213 }
214 if msg != "" {
215 data["Error"] = msg
216 }
217 if len(recoveryCodes) > 0 {
218 data["RecoveryCodes"] = recoveryCodes
219 }
220 h.renderPage(w, r, "settings/2fa_recovery", data)
221 }
222
223 row, err := h.q.GetUserTOTP(r.Context(), h.d.Pool, user.ID)
224 if err != nil {
225 h.d.Render.HTTPError(w, r, http.StatusBadRequest, "no pending 2FA enrollment")
226 return
227 }
228 if row.ConfirmedAt.Valid {
229 http.Redirect(w, r, "/settings/security/2fa/disable", http.StatusSeeOther)
230 return
231 }
232
233 secret, err := h.d.SecretBox.Open(row.SecretEncrypted, row.SecretNonce)
234 if err != nil {
235 h.d.Logger.ErrorContext(r.Context(), "2fa: open secret", "error", err)
236 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
237 return
238 }
239 step, err := totp.Verify(secret, code, time.Now())
240 if err != nil {
241 render("That code is incorrect. Try again.", nil)
242 return
243 }
244
245 codes, hashes, err := totp.GenerateRecoveryCodes()
246 if err != nil {
247 h.d.Logger.ErrorContext(r.Context(), "2fa: generate recovery", "error", err)
248 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
249 return
250 }
251
252 tx, err := h.d.Pool.Begin(r.Context())
253 if err != nil {
254 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
255 return
256 }
257 defer func() { _ = tx.Rollback(r.Context()) }()
258
259 // ConfirmUserTOTP only updates when confirmed_at IS NULL — handles the
260 // parallel-enrollment race; a second submit finds rows-affected==0.
261 rows, err := h.q.ConfirmUserTOTP(r.Context(), tx, usersdb.ConfirmUserTOTPParams{
262 UserID: user.ID,
263 LastUsedCounter: step,
264 })
265 if err != nil {
266 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
267 return
268 }
269 if rows == 0 {
270 // Already confirmed by a parallel request.
271 http.Redirect(w, r, "/settings/security/2fa/disable", http.StatusSeeOther)
272 return
273 }
274
275 if err := h.q.DeleteUserRecoveryCodes(r.Context(), tx, user.ID); err != nil {
276 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
277 return
278 }
279 for _, hsh := range hashes {
280 if err := h.q.InsertRecoveryCode(r.Context(), tx, usersdb.InsertRecoveryCodeParams{
281 UserID: user.ID, CodeHash: hsh,
282 }); err != nil {
283 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
284 return
285 }
286 }
287 if err := h.d.Audit.Record(r.Context(), tx, user.ID,
288 audit.Action2FAEnabled, audit.TargetUser, user.ID, nil); err != nil {
289 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
290 return
291 }
292 if err := h.d.Audit.Record(r.Context(), tx, user.ID,
293 audit.ActionRecoveryCodesIssued, audit.TargetUser, user.ID, map[string]any{"count": len(codes)}); err != nil {
294 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
295 return
296 }
297 if err := tx.Commit(r.Context()); err != nil {
298 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
299 return
300 }
301
302 // Just verified a fresh TOTP — gate-passing window starts now.
303 if s := middleware.SessionFromContext(r.Context()); s != nil {
304 s.Recent2FAAt = time.Now().Unix()
305 _ = h.d.SessionStore.Save(w, r, s)
306 }
307
308 h.notifyUser(r.Context(), user.ID, "2fa_enabled")
309
310 render("", codes)
311 }
312
313 // =============================== disable ================================
314
315 func (h *Handlers) twoFactorDisableForm(w http.ResponseWriter, r *http.Request) {
316 h.renderPage(w, r, "settings/2fa_disable", map[string]any{
317 "Title": "Disable two-factor authentication",
318 "CSRFToken": middleware.CSRFTokenForRequest(r),
319 "SettingsActive": "2fa",
320 })
321 }
322
323 func (h *Handlers) twoFactorDisableSubmit(w http.ResponseWriter, r *http.Request) {
324 if err := r.ParseForm(); err != nil {
325 h.d.Render.HTTPError(w, r, http.StatusBadRequest, "form parse")
326 return
327 }
328 user := middleware.CurrentUserFromContext(r.Context())
329 pw := r.PostFormValue("password")
330 code := strings.TrimSpace(r.PostFormValue("code"))
331
332 render := func(msg string) {
333 h.renderPage(w, r, "settings/2fa_disable", map[string]any{
334 "Title": "Disable two-factor authentication",
335 "CSRFToken": middleware.CSRFTokenForRequest(r),
336 "SettingsActive": "2fa",
337 "Error": msg,
338 })
339 }
340
341 if ok, err := h.confirmPasswordAndTOTP(r.Context(), user.ID, pw, code); err != nil {
342 h.d.Logger.ErrorContext(r.Context(), "2fa: disable confirm", "error", err)
343 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
344 return
345 } else if !ok {
346 render("Password or code incorrect. Please try again.")
347 return
348 }
349
350 tx, err := h.d.Pool.Begin(r.Context())
351 if err != nil {
352 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
353 return
354 }
355 defer func() { _ = tx.Rollback(r.Context()) }()
356
357 if err := h.q.DeleteUserTOTP(r.Context(), tx, user.ID); err != nil {
358 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
359 return
360 }
361 if err := h.q.DeleteUserRecoveryCodes(r.Context(), tx, user.ID); err != nil {
362 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
363 return
364 }
365 if err := h.d.Audit.Record(r.Context(), tx, user.ID,
366 audit.Action2FADisabled, audit.TargetUser, user.ID, nil); err != nil {
367 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
368 return
369 }
370 if err := tx.Commit(r.Context()); err != nil {
371 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
372 return
373 }
374
375 h.notifyUser(r.Context(), user.ID, "2fa_disabled")
376 http.Redirect(w, r, "/settings/security/2fa/enable?notice=disabled", http.StatusSeeOther)
377 }
378
379 // ============================== regenerate ==============================
380
381 func (h *Handlers) twoFactorRegenerateSubmit(w http.ResponseWriter, r *http.Request) {
382 if err := r.ParseForm(); err != nil {
383 h.d.Render.HTTPError(w, r, http.StatusBadRequest, "form parse")
384 return
385 }
386 user := middleware.CurrentUserFromContext(r.Context())
387 pw := r.PostFormValue("password")
388 code := strings.TrimSpace(r.PostFormValue("code"))
389
390 if ok, err := h.confirmPasswordAndTOTP(r.Context(), user.ID, pw, code); err != nil {
391 h.d.Logger.ErrorContext(r.Context(), "2fa: regen confirm", "error", err)
392 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
393 return
394 } else if !ok {
395 h.d.Render.HTTPError(w, r, http.StatusUnauthorized, "Password or code incorrect")
396 return
397 }
398
399 codes, hashes, err := totp.GenerateRecoveryCodes()
400 if err != nil {
401 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
402 return
403 }
404
405 tx, err := h.d.Pool.Begin(r.Context())
406 if err != nil {
407 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
408 return
409 }
410 defer func() { _ = tx.Rollback(r.Context()) }()
411
412 if err := h.q.DeleteUserRecoveryCodes(r.Context(), tx, user.ID); err != nil {
413 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
414 return
415 }
416 for _, hsh := range hashes {
417 if err := h.q.InsertRecoveryCode(r.Context(), tx, usersdb.InsertRecoveryCodeParams{
418 UserID: user.ID, CodeHash: hsh,
419 }); err != nil {
420 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
421 return
422 }
423 }
424 if err := h.d.Audit.Record(r.Context(), tx, user.ID,
425 audit.ActionRecoveryRegenerated, audit.TargetUser, user.ID, map[string]any{"count": len(codes)}); err != nil {
426 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
427 return
428 }
429 if err := tx.Commit(r.Context()); err != nil {
430 h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
431 return
432 }
433
434 h.notifyUser(r.Context(), user.ID, "recovery_regenerated")
435
436 h.renderPage(w, r, "settings/2fa_recovery", map[string]any{
437 "Title": "New recovery codes",
438 "CSRFToken": middleware.CSRFTokenForRequest(r),
439 "SettingsActive": "2fa",
440 "RecoveryCodes": codes,
441 })
442 }
443
444 // =============================== helpers =================================
445
446 // verifyTOTPCode verifies code against the user's confirmed TOTP secret
447 // AND advances last_used_counter atomically (counter anti-replay).
448 func (h *Handlers) verifyTOTPCode(ctx context.Context, userID int64, code string) (bool, error) {
449 row, err := h.q.GetUserTOTP(ctx, h.d.Pool, userID)
450 if err != nil {
451 return false, nil // no enrollment → reject without leaking
452 }
453 if !row.ConfirmedAt.Valid {
454 return false, nil
455 }
456 secret, err := h.d.SecretBox.Open(row.SecretEncrypted, row.SecretNonce)
457 if err != nil {
458 return false, fmt.Errorf("open secret: %w", err)
459 }
460 step, err := totp.Verify(secret, code, time.Now())
461 if err != nil {
462 return false, nil
463 }
464 rows, err := h.q.BumpTOTPCounter(ctx, h.d.Pool, usersdb.BumpTOTPCounterParams{
465 UserID: userID,
466 LastUsedCounter: step,
467 })
468 if err != nil {
469 return false, fmt.Errorf("bump counter: %w", err)
470 }
471 return rows == 1, nil // rows==0 means counter replay → reject
472 }
473
474 // consumeRecoveryCode hashes the typed code and tries to mark it used.
475 // Returns true iff exactly one matching unused row was found.
476 func (h *Handlers) consumeRecoveryCode(ctx context.Context, userID int64, code string) (bool, error) {
477 hash := totp.HashRecoveryCode(code)
478 rows, err := h.q.ConsumeRecoveryCode(ctx, h.d.Pool, usersdb.ConsumeRecoveryCodeParams{
479 UserID: userID, CodeHash: hash,
480 })
481 if err != nil {
482 return false, err
483 }
484 return rows == 1, nil
485 }
486
487 // confirmPasswordAndTOTP validates current password AND current TOTP
488 // before sensitive 2FA state changes. Returns (true, nil) on success,
489 // (false, nil) on a clean rejection, or (false, err) on a real error.
490 func (h *Handlers) confirmPasswordAndTOTP(ctx context.Context, userID int64, pw, code string) (bool, error) {
491 user, err := h.q.GetUserByID(ctx, h.d.Pool, userID)
492 if err != nil {
493 return false, err
494 }
495 ok, err := password.Verify(pw, user.PasswordHash)
496 if err != nil {
497 return false, err
498 }
499 if !ok {
500 return false, nil
501 }
502 if codeOK, err := h.verifyTOTPCode(ctx, userID, code); err != nil {
503 return false, err
504 } else if !codeOK {
505 return false, nil
506 }
507 return true, nil
508 }
509
510 // notifyUser sends a notification email about a 2FA state change. Best
511 // effort — failure is logged but does not break the flow.
512 func (h *Handlers) notifyUser(ctx context.Context, userID int64, kind string) {
513 user, err := h.q.GetUserByID(ctx, h.d.Pool, userID)
514 if err != nil {
515 return
516 }
517 if !user.PrimaryEmailID.Valid {
518 return
519 }
520 em, err := h.q.GetUserEmailByID(ctx, h.d.Pool, user.PrimaryEmailID.Int64)
521 if err != nil {
522 return
523 }
524 msg, err := email.NoticeMessage(h.d.Branding, string(em.Email), user.Username, kind)
525 if err != nil {
526 h.d.Logger.WarnContext(ctx, "notice: build", "kind", kind, "error", err)
527 return
528 }
529 if err := h.d.Email.Send(ctx, msg); err != nil {
530 h.d.Logger.WarnContext(ctx, "notice: send", "kind", kind, "error", err)
531 }
532 }
533
534 // silence unused import warnings if guards are removed.
535 var (
536 _ = pgx.ErrNoRows
537 _ = pgtype.Int8{}
538 _ = errors.New
539 )
540