Go · 9002 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package auth_test
4
5 import (
6 "context"
7 "encoding/json"
8 "io/fs"
9 "log/slog"
10 "net/http"
11 "net/http/httptest"
12 "net/url"
13 "strings"
14 "testing"
15 "testing/fstest"
16
17 "github.com/go-chi/chi/v5"
18 "github.com/jackc/pgx/v5/pgxpool"
19
20 "github.com/tenseleyFlow/shithub/internal/auth/audit"
21 "github.com/tenseleyFlow/shithub/internal/auth/devicecode"
22 "github.com/tenseleyFlow/shithub/internal/auth/email"
23 "github.com/tenseleyFlow/shithub/internal/auth/password"
24 "github.com/tenseleyFlow/shithub/internal/auth/session"
25 "github.com/tenseleyFlow/shithub/internal/auth/throttle"
26 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
27 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
28 authh "github.com/tenseleyFlow/shithub/internal/web/handlers/auth"
29 "github.com/tenseleyFlow/shithub/internal/web/render"
30 )
31
32 // newDeviceCodeRouter wires only the JSON device-code endpoints on a
33 // bare router (no CSRF, no session loader) — these endpoints are
34 // always CSRF-exempt in production. We use a minimal templates FS
35 // because the JSON handlers never render HTML.
36 func newDeviceCodeRouter(t *testing.T) (http.Handler, *pgxpool.Pool) {
37 t.Helper()
38 pool := dbtest.NewTestDB(t)
39 tmplFS := fstest.MapFS{
40 "_layout.html": {Data: []byte(`{{ define "layout" }}{{ template "page" . }}{{ end }}`)},
41 "hello.html": {Data: []byte(`{{ define "page" }}home{{ end }}`)},
42 }
43 rr, err := render.New(fs.FS(tmplFS), render.Options{})
44 if err != nil {
45 t.Fatalf("render.New: %v", err)
46 }
47 storeKey, err := session.GenerateKey()
48 if err != nil {
49 t.Fatalf("session key: %v", err)
50 }
51 store, err := session.NewCookieStore(session.CookieStoreConfig{Key: storeKey, MaxAge: 0, Secure: false})
52 if err != nil {
53 t.Fatalf("NewCookieStore: %v", err)
54 }
55 h, err := authh.New(authh.Deps{
56 Logger: slog.Default(),
57 Render: rr,
58 Pool: pool,
59 SessionStore: store,
60 Email: &noopSender{},
61 Branding: email.Branding{
62 SiteName: "shithub", BaseURL: "http://test.invalid",
63 From: "noreply@shithub.test",
64 },
65 Argon2: password.Params{
66 Memory: 1024, Time: 1, Threads: 1, SaltLen: 16, KeyLen: 32,
67 },
68 Limiter: throttle.NewLimiter(),
69 Audit: audit.NewRecorder(),
70 DeviceCode: devicecode.Config{
71 ClientIDs: []string{"shithub-cli"},
72 DefaultScopes: []string{"user:read"},
73 },
74 })
75 if err != nil {
76 t.Fatalf("authh.New: %v", err)
77 }
78 r := chi.NewRouter()
79 h.MountDeviceCodeAPI(r)
80 return r, pool
81 }
82
83 type noopSender struct{}
84
85 func (noopSender) Send(ctx context.Context, msg email.Message) error { return nil }
86
87 // seedDeviceUser inserts a user row that the Approve path can stamp
88 // as the authorizer. The hash is a constant argon2id digest; we don't
89 // log this user in via the password flow so the value doesn't have to
90 // match anything meaningful.
91 const seedDeviceUserHash = "$argon2id$v=19$m=1024,t=1,p=1$YWFhYWFhYWFhYWFhYWFhYQ$" +
92 "DvBOTSnFhCBe+Pfx/W7Sk3hG3JCm2Wj0RBgCu+CPDtY"
93
94 func seedDeviceUser(t *testing.T, pool *pgxpool.Pool, username string) int64 {
95 t.Helper()
96 q := usersdb.New()
97 u, err := q.CreateUser(context.Background(), pool, usersdb.CreateUserParams{
98 Username: username,
99 DisplayName: strings.ToUpper(username[:1]) + username[1:],
100 PasswordHash: seedDeviceUserHash,
101 })
102 if err != nil {
103 t.Fatalf("CreateUser: %v", err)
104 }
105 em, err := q.CreateUserEmail(context.Background(), pool, usersdb.CreateUserEmailParams{
106 UserID: u.ID,
107 Email: username + "@example.test",
108 IsPrimary: true,
109 })
110 if err != nil {
111 t.Fatalf("CreateUserEmail: %v", err)
112 }
113 if err := q.MarkUserEmailVerified(context.Background(), pool, em.ID); err != nil {
114 t.Fatalf("MarkUserEmailVerified: %v", err)
115 }
116 return u.ID
117 }
118
// formPost encodes body as application/x-www-form-urlencoded, POSTs it to
// path through router in-process, and returns the recorded response.
func formPost(router http.Handler, path string, body url.Values) *httptest.ResponseRecorder {
	var (
		payload = strings.NewReader(body.Encode())
		req     = httptest.NewRequest(http.MethodPost, path, payload)
		rec     = httptest.NewRecorder()
	)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	router.ServeHTTP(rec, req)
	return rec
}
126
// JSON response bodies decoded from the device-code endpoints.
type (
	// oauthErr is an OAuth-style error body: a machine-readable error code
	// plus a human-readable description.
	oauthErr struct {
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}

	// deviceIssue is the successful response from /login/device/code.
	deviceIssue struct {
		DeviceCode              string `json:"device_code"`
		UserCode                string `json:"user_code"`
		VerificationURI         string `json:"verification_uri"`
		VerificationURIComplete string `json:"verification_uri_complete"`
		ExpiresIn               int    `json:"expires_in"`
		Interval                int    `json:"interval"`
	}

	// deviceExchange is the successful response from
	// /login/oauth/access_token.
	deviceExchange struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		Scope       string `json:"scope"`
	}
)
146
147 func TestDeviceAPI_IssueShape(t *testing.T) {
148 router, _ := newDeviceCodeRouter(t)
149 rr := formPost(router, "/login/device/code", url.Values{
150 "client_id": {"shithub-cli"},
151 "scope": {"user:read repo:read"},
152 })
153 if rr.Code != http.StatusOK {
154 t.Fatalf("status: got %d; body=%s", rr.Code, rr.Body.String())
155 }
156 var out deviceIssue
157 if err := json.Unmarshal(rr.Body.Bytes(), &out); err != nil {
158 t.Fatalf("decode: %v", err)
159 }
160 if out.DeviceCode == "" || out.UserCode == "" {
161 t.Fatalf("empty codes: %+v", out)
162 }
163 if out.VerificationURI != "http://test.invalid/login/device" {
164 t.Errorf("verification_uri: got %q", out.VerificationURI)
165 }
166 if !strings.Contains(out.VerificationURIComplete, "user_code=") {
167 t.Errorf("verification_uri_complete missing user_code: %q", out.VerificationURIComplete)
168 }
169 if out.ExpiresIn <= 0 || out.Interval <= 0 {
170 t.Errorf("expiry/interval not populated: %+v", out)
171 }
172 }
173
174 func TestDeviceAPI_IssueRejectsUnknownClient(t *testing.T) {
175 router, _ := newDeviceCodeRouter(t)
176 rr := formPost(router, "/login/device/code", url.Values{"client_id": {"evil-cli"}})
177 if rr.Code != http.StatusBadRequest {
178 t.Fatalf("status: got %d; body=%s", rr.Code, rr.Body.String())
179 }
180 var oe oauthErr
181 _ = json.Unmarshal(rr.Body.Bytes(), &oe)
182 if oe.Error != "unauthorized_client" {
183 t.Errorf("error code: got %q", oe.Error)
184 }
185 }
186
187 func TestDeviceAPI_ExchangeRejectsWrongGrantType(t *testing.T) {
188 router, _ := newDeviceCodeRouter(t)
189 rr := formPost(router, "/login/oauth/access_token", url.Values{
190 "grant_type": {"authorization_code"},
191 "client_id": {"shithub-cli"},
192 })
193 if rr.Code != http.StatusBadRequest {
194 t.Fatalf("status: got %d; body=%s", rr.Code, rr.Body.String())
195 }
196 var oe oauthErr
197 _ = json.Unmarshal(rr.Body.Bytes(), &oe)
198 if oe.Error != "unsupported_grant_type" {
199 t.Errorf("error: got %q", oe.Error)
200 }
201 }
202
203 func TestDeviceAPI_ExchangePendingThenApproved(t *testing.T) {
204 router, pool := newDeviceCodeRouter(t)
205 userID := seedDeviceUser(t, pool, "alice")
206
207 // 1) Issue a code via HTTP.
208 rr := formPost(router, "/login/device/code", url.Values{
209 "client_id": {"shithub-cli"},
210 "scope": {"user:read"},
211 })
212 if rr.Code != http.StatusOK {
213 t.Fatalf("issue status: got %d; body=%s", rr.Code, rr.Body.String())
214 }
215 var issued deviceIssue
216 _ = json.Unmarshal(rr.Body.Bytes(), &issued)
217
218 // 2) Pending poll → 400 authorization_pending.
219 const exchangeGrant = "urn:ietf:params:oauth:grant-type:device_code"
220 rr = formPost(router, "/login/oauth/access_token", url.Values{
221 "grant_type": {exchangeGrant},
222 "client_id": {"shithub-cli"},
223 "device_code": {issued.DeviceCode},
224 })
225 if rr.Code != http.StatusBadRequest {
226 t.Fatalf("pending status: got %d; body=%s", rr.Code, rr.Body.String())
227 }
228 var oe oauthErr
229 _ = json.Unmarshal(rr.Body.Bytes(), &oe)
230 if oe.Error != "authorization_pending" {
231 t.Errorf("pending error: got %q", oe.Error)
232 }
233
234 // 3) Approve directly via the orchestrator (HTML form flow would do
235 // this via the verification page; we shortcut here since this test
236 // is about the JSON exchange path).
237 row, err := devicecode.LookupByUserCode(context.Background(), devicecode.Deps{Pool: pool}, issued.UserCode)
238 if err != nil {
239 t.Fatalf("LookupByUserCode: %v", err)
240 }
241 if err := devicecode.Approve(context.Background(), devicecode.Deps{Pool: pool}, row.ID, userID); err != nil {
242 t.Fatalf("Approve: %v", err)
243 }
244 // Rewind last_polled_at so the slow_down gate doesn't bite the next exchange.
245 if _, err := pool.Exec(context.Background(),
246 "UPDATE device_authorizations SET last_polled_at = now() - interval '10 seconds' WHERE id = $1",
247 row.ID); err != nil {
248 t.Fatalf("rewind last_polled_at: %v", err)
249 }
250
251 // 4) Exchange after approval → 200 + access_token.
252 rr = formPost(router, "/login/oauth/access_token", url.Values{
253 "grant_type": {exchangeGrant},
254 "client_id": {"shithub-cli"},
255 "device_code": {issued.DeviceCode},
256 })
257 if rr.Code != http.StatusOK {
258 t.Fatalf("exchange status: got %d; body=%s", rr.Code, rr.Body.String())
259 }
260 var ex deviceExchange
261 _ = json.Unmarshal(rr.Body.Bytes(), &ex)
262 if !strings.HasPrefix(ex.AccessToken, "shithub_pat_") {
263 t.Errorf("access_token: got %q", ex.AccessToken)
264 }
265 if ex.TokenType != "bearer" {
266 t.Errorf("token_type: got %q", ex.TokenType)
267 }
268 if ex.Scope != "user:read" {
269 t.Errorf("scope: got %q", ex.Scope)
270 }
271 }
272