tenseleyflow/shithub / 639aace

Browse files

Add 2FA integration tests: enroll-login-challenge, recovery one-time, counter replay, disable confirm

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
639aaceac768e5f6827cdd794b1485488b0d808e
Parents
40539d7
Tree
bf8a4fa

2 changed files

StatusFile+-
M internal/web/handlers/auth/auth_test.go 42 11
A internal/web/handlers/auth/twofactor_test.go 287 0
internal/web/handlers/auth/auth_test.gomodified
@@ -24,6 +24,7 @@ import (
2424
 
2525
 	"github.com/tenseleyFlow/shithub/internal/auth/email"
2626
 	"github.com/tenseleyFlow/shithub/internal/auth/password"
27
+	"github.com/tenseleyFlow/shithub/internal/auth/secretbox"
2728
 	"github.com/tenseleyFlow/shithub/internal/auth/session"
2829
 	"github.com/tenseleyFlow/shithub/internal/auth/throttle"
2930
 	"github.com/tenseleyFlow/shithub/internal/testing/dbtest"
@@ -82,6 +83,15 @@ func newTestServer(t *testing.T, requireVerify bool) (*httptest.Server, *capture
8283
 	cap := &captureSender{}
8384
 	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
8485
 
86
+	totpKey, err := secretbox.GenerateKey()
87
+	if err != nil {
88
+		t.Fatalf("secretbox key: %v", err)
89
+	}
90
+	box, err := secretbox.FromBytes(totpKey)
91
+	if err != nil {
92
+		t.Fatalf("secretbox: %v", err)
93
+	}
94
+
8595
 	h, err := authh.New(authh.Deps{
8696
 		Logger:       logger,
8797
 		Render:       rr,
@@ -95,6 +105,7 @@ func newTestServer(t *testing.T, requireVerify bool) (*httptest.Server, *capture
95105
 		Argon2:                   fastArgon,
96106
 		Limiter:                  throttle.NewLimiter(),
97107
 		RequireEmailVerification: requireVerify,
108
+		SecretBox:                box,
98109
 	})
99110
 	if err != nil {
100111
 		t.Fatalf("authh.New: %v", err)
@@ -104,6 +115,18 @@ func newTestServer(t *testing.T, requireVerify bool) (*httptest.Server, *capture
104115
 	r.Use(middleware.RequestID)
105116
 	r.Use(middleware.RealIP(middleware.RealIPConfig{}))
106117
 	r.Use(middleware.SessionLoader(store, logger))
118
+	r.Use(middleware.OptionalUser(func(ctx context.Context, id int64) (string, error) {
119
+		// Cheap username lookup against the test pool — RequireUser only
120
+		// checks ID == 0, but settings handlers use the username.
121
+		u, err := pool.Acquire(ctx)
122
+		if err != nil {
123
+			return "", err
124
+		}
125
+		defer u.Release()
126
+		var name string
127
+		err = u.QueryRow(ctx, "SELECT username FROM users WHERE id = $1", id).Scan(&name)
128
+		return name, err
129
+	}))
107130
 	csrf := middleware.CSRF(middleware.CSRFConfig{
108131
 		FailureHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109132
 			http.Error(w, "csrf: "+nosurfReason(r), http.StatusForbidden)
@@ -130,19 +153,27 @@ func authTemplatesFS() fs.FS {
130153
 	resetReq := `{{ define "page" }}<form>{{ with .Notice }}<p class=notice>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}"></form>{{ end }}`
131154
 	resetConf := `{{ define "page" }}<form>{{ with .Error }}<p class=error>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}">{{.Token}}</form>{{ end }}`
132155
 	verifyResend := `{{ define "page" }}<form>{{ with .Notice }}<p class=notice>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}"></form>{{ end }}`
156
+	tfaChallenge := `{{ define "page" }}<form>{{ with .Error }}<p class=error>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}"><input name=next value="{{.Next}}"></form>{{ end }}`
157
+	tfaEnable := `{{ define "page" }}<form>{{ with .Error }}<p class=error>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}">SECRET={{.Secret}}</form>{{ end }}`
158
+	tfaDisable := `{{ define "page" }}<form>{{ with .Error }}<p class=error>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}"></form>{{ end }}`
159
+	tfaRecovery := `{{ define "page" }}<form>{{ with .Error }}<p class=error>{{.}}</p>{{ end }}<input name=csrf_token value="{{.CSRFToken}}">{{ if .RecoveryCodes }}CODES={{ range .RecoveryCodes }}{{.}};{{ end }}{{ end }}</form>{{ end }}`
133160
 	errorPage := `{{ define "page" }}<h1>{{.Status}} {{.StatusText}}</h1><p>{{.Message}}</p>{{ end }}`
134161
 	return fstest.MapFS{
135
-		"_layout.html":            {Data: []byte(layout)},
136
-		"hello.html":              {Data: []byte(`{{ define "page" }}home{{ end }}`)},
137
-		"auth/signup.html":        {Data: []byte(signup)},
138
-		"auth/login.html":         {Data: []byte(login)},
139
-		"auth/reset_request.html": {Data: []byte(resetReq)},
140
-		"auth/reset_confirm.html": {Data: []byte(resetConf)},
141
-		"auth/verify_resend.html": {Data: []byte(verifyResend)},
142
-		"errors/404.html":         {Data: []byte(errorPage)},
143
-		"errors/403.html":         {Data: []byte(errorPage)},
144
-		"errors/429.html":         {Data: []byte(errorPage)},
145
-		"errors/500.html":         {Data: []byte(errorPage)},
162
+		"_layout.html":               {Data: []byte(layout)},
163
+		"hello.html":                 {Data: []byte(`{{ define "page" }}home{{ end }}`)},
164
+		"auth/signup.html":           {Data: []byte(signup)},
165
+		"auth/login.html":            {Data: []byte(login)},
166
+		"auth/reset_request.html":    {Data: []byte(resetReq)},
167
+		"auth/reset_confirm.html":    {Data: []byte(resetConf)},
168
+		"auth/verify_resend.html":    {Data: []byte(verifyResend)},
169
+		"auth/2fa_challenge.html":    {Data: []byte(tfaChallenge)},
170
+		"settings/2fa_enable.html":   {Data: []byte(tfaEnable)},
171
+		"settings/2fa_disable.html":  {Data: []byte(tfaDisable)},
172
+		"settings/2fa_recovery.html": {Data: []byte(tfaRecovery)},
173
+		"errors/404.html":            {Data: []byte(errorPage)},
174
+		"errors/403.html":            {Data: []byte(errorPage)},
175
+		"errors/429.html":            {Data: []byte(errorPage)},
176
+		"errors/500.html":            {Data: []byte(errorPage)},
146177
 	}
147178
 }
148179
 
internal/web/handlers/auth/twofactor_test.goadded
@@ -0,0 +1,287 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package auth_test
4
+
5
+import (
6
+	"encoding/base32"
7
+	"io"
8
+	"net/http"
9
+	"net/url"
10
+	"regexp"
11
+	"strings"
12
+	"testing"
13
+	"time"
14
+
15
+	"github.com/tenseleyFlow/shithub/internal/auth/totp"
16
+)
17
+
18
+// base32StdLowerStrict is the standard base32 alphabet used by EncodeBase32.
19
+var base32StdLowerStrict = base32.StdEncoding
20
+
21
+// enrollTOTPHelper completes the signup → enable-2FA flow and returns
22
+// everything subsequent tests need to assert against the live server.
23
+func enrollTOTPHelper(t *testing.T, requireVerify bool) (
24
+	cli *client,
25
+	captor *captureSender,
26
+	username string,
27
+	password string,
28
+	secret []byte,
29
+	recovery []string,
30
+) {
31
+	t.Helper()
32
+	httpsrv, cap := newTestServer(t, requireVerify)
33
+	captor = cap
34
+	cli = newClient(t, httpsrv)
35
+
36
+	username = "alice2fa"
37
+	password = "correct horse battery staple"
38
+	mustSignup(t, cli, username, "alice2fa@example.com", password)
39
+
40
+	// Verify email so login works regardless of requireVerify.
41
+	tok := extractTokenFromMessage(t, captor.all()[0], "/verify-email")
42
+	resp := cli.get(t, "/verify-email/"+tok)
43
+	_ = resp.Body.Close()
44
+
45
+	// First login — no 2FA yet, so this lands on / .
46
+	csrf := cli.extractCSRF(t, "/login")
47
+	resp = cli.post(t, "/login", url.Values{
48
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
49
+	})
50
+	if resp.StatusCode != http.StatusSeeOther {
51
+		body, _ := io.ReadAll(resp.Body)
52
+		t.Fatalf("first login: %d %s", resp.StatusCode, body)
53
+	}
54
+	_ = resp.Body.Close()
55
+
56
+	// GET enable form — server mints a fresh secret and sends it inline
57
+	// (test template emits SECRET=<base32> in the body).
58
+	resp = cli.get(t, "/settings/security/2fa/enable")
59
+	if resp.StatusCode != http.StatusOK {
60
+		body, _ := io.ReadAll(resp.Body)
61
+		t.Fatalf("2fa enable form: %d %s", resp.StatusCode, body)
62
+	}
63
+	body, _ := io.ReadAll(resp.Body)
64
+	_ = resp.Body.Close()
65
+	secretRE := regexp.MustCompile(`SECRET=([A-Z2-7]+)`)
66
+	m := secretRE.FindStringSubmatch(string(body))
67
+	if m == nil {
68
+		t.Fatalf("no secret in enable form body: %s", body)
69
+	}
70
+	secret = decodeBase32(t, m[1])
71
+
72
+	// Compute current TOTP code from the secret and submit it.
73
+	csrf = extractCSRFFromBody(t, body)
74
+	code, err := totp.Generate(secret, time.Now())
75
+	if err != nil {
76
+		t.Fatalf("generate code: %v", err)
77
+	}
78
+	resp = cli.post(t, "/settings/security/2fa/enable", url.Values{
79
+		"csrf_token": {csrf}, "code": {code},
80
+	})
81
+	if resp.StatusCode != http.StatusOK {
82
+		body2, _ := io.ReadAll(resp.Body)
83
+		t.Fatalf("2fa confirm: %d %s", resp.StatusCode, body2)
84
+	}
85
+	body, _ = io.ReadAll(resp.Body)
86
+	_ = resp.Body.Close()
87
+	codesRE := regexp.MustCompile(`CODES=([A-Z0-9\-;]+)`)
88
+	cm := codesRE.FindStringSubmatch(string(body))
89
+	if cm == nil {
90
+		t.Fatalf("no recovery codes in body: %s", body)
91
+	}
92
+	recovery = strings.Split(strings.TrimSuffix(cm[1], ";"), ";")
93
+
94
+	return cli, captor, username, password, secret, recovery
95
+}
96
+
97
+// extractCSRFFromBody finds the csrf token in an already-fetched body.
98
+func extractCSRFFromBody(t *testing.T, body []byte) string {
99
+	t.Helper()
100
+	m := csrfMarkerRE.FindStringSubmatch(string(body))
101
+	if m == nil {
102
+		t.Fatalf("no csrf marker in body: %s", body)
103
+	}
104
+	return htmlUnescape(m[1])
105
+}
106
+
107
+// htmlUnescape mirrors what extractCSRF does internally.
108
+func htmlUnescape(s string) string {
109
+	return strings.NewReplacer("&#43;", "+", "&#47;", "/", "&#61;", "=").Replace(s)
110
+}
111
+
112
+func decodeBase32(t *testing.T, s string) []byte {
113
+	t.Helper()
114
+	// pad to multiple of 8.
115
+	for len(s)%8 != 0 {
116
+		s += "="
117
+	}
118
+	b, err := base32StdLowerStrict.DecodeString(s)
119
+	if err != nil {
120
+		t.Fatalf("base32 decode: %v", err)
121
+	}
122
+	return b
123
+}
124
+
125
+// ============================== tests ==================================
126
+
127
+func TestTwoFactor_Enroll_Logout_Login_Challenge_FullSession(t *testing.T) {
128
+	t.Parallel()
129
+	cli, _, username, password, secret, _ := enrollTOTPHelper(t, false)
130
+
131
+	// Logout.
132
+	csrf := cli.extractCSRF(t, "/login")
133
+	resp := cli.post(t, "/logout", url.Values{"csrf_token": {csrf}})
134
+	if resp.StatusCode != http.StatusSeeOther {
135
+		t.Fatalf("logout: %d", resp.StatusCode)
136
+	}
137
+	_ = resp.Body.Close()
138
+
139
+	// Re-login with password — should redirect to /login/2fa.
140
+	csrf = cli.extractCSRF(t, "/login")
141
+	resp = cli.post(t, "/login", url.Values{
142
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
143
+	})
144
+	if resp.StatusCode != http.StatusSeeOther {
145
+		body, _ := io.ReadAll(resp.Body)
146
+		t.Fatalf("login after enroll: %d body=%s", resp.StatusCode, body)
147
+	}
148
+	if loc := resp.Header.Get("Location"); loc != "/login/2fa" {
149
+		t.Fatalf("expected redirect to /login/2fa, got %q", loc)
150
+	}
151
+	_ = resp.Body.Close()
152
+
153
+	// Challenge form, submit a TOTP from the NEXT step. Enrollment-confirm
154
+	// already advanced last_used_counter to the current step; with ±1
155
+	// skew, a code from now+30s decodes to step+1 and is accepted.
156
+	csrf = cli.extractCSRF(t, "/login/2fa")
157
+	code, _ := totp.Generate(secret, time.Now().Add(30*time.Second))
158
+	resp = cli.post(t, "/login/2fa", url.Values{"csrf_token": {csrf}, "code": {code}})
159
+	if resp.StatusCode != http.StatusSeeOther {
160
+		body, _ := io.ReadAll(resp.Body)
161
+		t.Fatalf("2fa challenge: %d body=%s", resp.StatusCode, body)
162
+	}
163
+	if loc := resp.Header.Get("Location"); loc != "/" {
164
+		t.Fatalf("expected / after challenge, got %q", loc)
165
+	}
166
+	_ = resp.Body.Close()
167
+}
168
+
169
+func TestTwoFactor_RecoveryCode_OneTimeUse(t *testing.T) {
170
+	t.Parallel()
171
+	cli, _, username, password, _, recovery := enrollTOTPHelper(t, false)
172
+	if len(recovery) == 0 {
173
+		t.Fatal("no recovery codes captured")
174
+	}
175
+	code := recovery[0]
176
+
177
+	// Logout.
178
+	csrf := cli.extractCSRF(t, "/login")
179
+	_ = cli.post(t, "/logout", url.Values{"csrf_token": {csrf}}).Body.Close()
180
+
181
+	// Login + recovery code → session upgraded.
182
+	csrf = cli.extractCSRF(t, "/login")
183
+	resp := cli.post(t, "/login", url.Values{
184
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
185
+	})
186
+	_ = resp.Body.Close()
187
+	csrf = cli.extractCSRF(t, "/login/2fa")
188
+	resp = cli.post(t, "/login/2fa", url.Values{"csrf_token": {csrf}, "code": {code}})
189
+	if resp.StatusCode != http.StatusSeeOther {
190
+		body, _ := io.ReadAll(resp.Body)
191
+		t.Fatalf("recovery first use: %d %s", resp.StatusCode, body)
192
+	}
193
+	_ = resp.Body.Close()
194
+
195
+	// Logout, second use of the same code MUST fail.
196
+	csrf = cli.extractCSRF(t, "/login")
197
+	_ = cli.post(t, "/logout", url.Values{"csrf_token": {csrf}}).Body.Close()
198
+	csrf = cli.extractCSRF(t, "/login")
199
+	_ = cli.post(t, "/login", url.Values{
200
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
201
+	}).Body.Close()
202
+	csrf = cli.extractCSRF(t, "/login/2fa")
203
+	resp = cli.post(t, "/login/2fa", url.Values{"csrf_token": {csrf}, "code": {code}})
204
+	if resp.StatusCode != http.StatusOK {
205
+		t.Fatalf("recovery second use: expected 200 (form re-render), got %d", resp.StatusCode)
206
+	}
207
+	_ = resp.Body.Close()
208
+}
209
+
210
+func TestTwoFactor_CounterReplayRejected(t *testing.T) {
211
+	t.Parallel()
212
+	cli, _, username, password, secret, _ := enrollTOTPHelper(t, false)
213
+
214
+	// Logout.
215
+	csrf := cli.extractCSRF(t, "/login")
216
+	_ = cli.post(t, "/logout", url.Values{"csrf_token": {csrf}}).Body.Close()
217
+
218
+	// First login + TOTP from the NEXT step (enrollment already used the
219
+	// current step's counter). With ±1 skew the server accepts now+30s.
220
+	code, _ := totp.Generate(secret, time.Now().Add(30*time.Second))
221
+	csrf = cli.extractCSRF(t, "/login")
222
+	_ = cli.post(t, "/login", url.Values{
223
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
224
+	}).Body.Close()
225
+	csrf = cli.extractCSRF(t, "/login/2fa")
226
+	resp := cli.post(t, "/login/2fa", url.Values{"csrf_token": {csrf}, "code": {code}})
227
+	if resp.StatusCode != http.StatusSeeOther {
228
+		body, _ := io.ReadAll(resp.Body)
229
+		t.Fatalf("first 2fa: %d body=%s", resp.StatusCode, body)
230
+	}
231
+	_ = resp.Body.Close()
232
+
233
+	// Logout.
234
+	csrf = cli.extractCSRF(t, "/login")
235
+	_ = cli.post(t, "/logout", url.Values{"csrf_token": {csrf}}).Body.Close()
236
+
237
+	// Second login attempt within the SAME 30-second window — same code.
238
+	// Expect REJECTION (counter replay).
239
+	csrf = cli.extractCSRF(t, "/login")
240
+	_ = cli.post(t, "/login", url.Values{
241
+		"csrf_token": {csrf}, "username": {username}, "password": {password},
242
+	}).Body.Close()
243
+	csrf = cli.extractCSRF(t, "/login/2fa")
244
+	resp = cli.post(t, "/login/2fa", url.Values{"csrf_token": {csrf}, "code": {code}})
245
+	if resp.StatusCode != http.StatusOK {
246
+		t.Fatalf("counter-replay: expected 200 (rejected), got %d", resp.StatusCode)
247
+	}
248
+	body, _ := io.ReadAll(resp.Body)
249
+	_ = resp.Body.Close()
250
+	if !strings.Contains(string(body), "Incorrect code") {
251
+		t.Fatalf("expected 'Incorrect code' message, got: %s", body)
252
+	}
253
+}
254
+
255
+func TestTwoFactor_DisableRequiresPasswordAndTOTP(t *testing.T) {
256
+	t.Parallel()
257
+	cli, _, _, password, secret, _ := enrollTOTPHelper(t, false)
258
+
259
+	// Wrong password → form re-rendered with error.
260
+	csrf := cli.extractCSRF(t, "/settings/security/2fa/disable")
261
+	resp := cli.post(t, "/settings/security/2fa/disable", url.Values{
262
+		"csrf_token": {csrf},
263
+		"password":   {"wrong"},
264
+		"code":       {"000000"},
265
+	})
266
+	if resp.StatusCode != http.StatusOK {
267
+		body, _ := io.ReadAll(resp.Body)
268
+		t.Fatalf("disable wrong creds: %d %s", resp.StatusCode, body)
269
+	}
270
+	body, _ := io.ReadAll(resp.Body)
271
+	_ = resp.Body.Close()
272
+	if !strings.Contains(string(body), "incorrect") {
273
+		t.Fatalf("expected 'incorrect' in error, got: %s", body)
274
+	}
275
+
276
+	// Correct password + correct TOTP → succeed.
277
+	csrf = cli.extractCSRF(t, "/settings/security/2fa/disable")
278
+	code, _ := totp.Generate(secret, time.Now().Add(31*time.Second)) // ensure NEW counter step vs enrollment
279
+	resp = cli.post(t, "/settings/security/2fa/disable", url.Values{
280
+		"csrf_token": {csrf}, "password": {password}, "code": {code},
281
+	})
282
+	if resp.StatusCode != http.StatusSeeOther {
283
+		body2, _ := io.ReadAll(resp.Body)
284
+		t.Fatalf("disable: %d %s", resp.StatusCode, body2)
285
+	}
286
+	_ = resp.Body.Close()
287
+}