Go · 11109 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package devicecode_test
4
5 import (
6 "context"
7 "errors"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/jackc/pgx/v5/pgtype"
13 "github.com/jackc/pgx/v5/pgxpool"
14
15 "github.com/tenseleyFlow/shithub/internal/auth/devicecode"
16 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
17 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
18 )
19
20 // fixedArgon2Hash is a pre-computed argon2id hash of an arbitrary password.
21 // Using a constant avoids running argon2 in every test setup.
22 const fixedArgon2Hash = "$argon2id$v=19$m=1024,t=1,p=1$YWFhYWFhYWFhYWFhYWFhYQ$" +
23 "DvBOTSnFhCBe+Pfx/W7Sk3hG3JCm2Wj0RBgCu+CPDtY"
24
25 func seedUser(t *testing.T, pool *pgxpool.Pool, username string) int64 {
26 t.Helper()
27 q := usersdb.New()
28 u, err := q.CreateUser(context.Background(), pool, usersdb.CreateUserParams{
29 Username: username,
30 DisplayName: strings.ToUpper(username[:1]) + username[1:],
31 PasswordHash: fixedArgon2Hash,
32 })
33 if err != nil {
34 t.Fatalf("CreateUser: %v", err)
35 }
36 em, err := q.CreateUserEmail(context.Background(), pool, usersdb.CreateUserEmailParams{
37 UserID: u.ID,
38 Email: username + "@example.test",
39 IsPrimary: true,
40 })
41 if err != nil {
42 t.Fatalf("CreateUserEmail: %v", err)
43 }
44 if err := q.MarkUserEmailVerified(context.Background(), pool, em.ID); err != nil {
45 t.Fatalf("MarkUserEmailVerified: %v", err)
46 }
47 return u.ID
48 }
49
50 func defaultsForTest() devicecode.Config {
51 return devicecode.Config{
52 ClientIDs: []string{"shithub-cli"},
53 DefaultScopes: []string{"user:read"},
54 ExpiresIn: 15 * time.Minute,
55 PollInterval: 5 * time.Second,
56 }
57 }
58
59 func TestCreate_HappyPath(t *testing.T) {
60 pool := dbtest.NewTestDB(t)
61 deps := devicecode.Deps{Pool: pool}
62 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "user:read,repo:read")
63 if err != nil {
64 t.Fatalf("Create: %v", err)
65 }
66 if auth.DeviceCode == "" || auth.UserCode == "" {
67 t.Fatalf("Create returned empty codes: %+v", auth)
68 }
69 if !strings.Contains(auth.UserCode, "-") || len(auth.UserCode) != 9 {
70 t.Errorf("user_code shape: got %q", auth.UserCode)
71 }
72 if auth.ExpiresIn != 15*time.Minute {
73 t.Errorf("expires_in: got %s", auth.ExpiresIn)
74 }
75 if got, want := auth.Scopes, []string{"user:read", "repo:read"}; !equalStrings(got, want) {
76 t.Errorf("scopes: got %v, want %v", got, want)
77 }
78 }
79
80 func TestCreate_RejectsUnknownClient(t *testing.T) {
81 pool := dbtest.NewTestDB(t)
82 _, err := devicecode.Create(context.Background(), devicecode.Deps{Pool: pool}, defaultsForTest(), "evil-cli", "")
83 if !errors.Is(err, devicecode.ErrUnauthorizedClient) {
84 t.Fatalf("got err %v, want ErrUnauthorizedClient", err)
85 }
86 }
87
88 func TestCreate_RejectsUnknownScope(t *testing.T) {
89 pool := dbtest.NewTestDB(t)
90 _, err := devicecode.Create(context.Background(), devicecode.Deps{Pool: pool}, defaultsForTest(), "shithub-cli", "user:read,bogus:scope")
91 if !errors.Is(err, devicecode.ErrInvalidScope) {
92 t.Fatalf("got err %v, want ErrInvalidScope", err)
93 }
94 }
95
96 func TestExchange_PendingThenApprovedThenOneShot(t *testing.T) {
97 pool := dbtest.NewTestDB(t)
98 deps := devicecode.Deps{Pool: pool}
99 userID := seedUser(t, pool, "alice")
100
101 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "user:read")
102 if err != nil {
103 t.Fatalf("Create: %v", err)
104 }
105
106 // Pending: Exchange returns authorization_pending.
107 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAuthorizationPending) {
108 t.Fatalf("pending: got %v, want ErrAuthorizationPending", err)
109 }
110
111 row, err := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
112 if err != nil {
113 t.Fatalf("LookupByUserCode: %v", err)
114 }
115 if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
116 t.Fatalf("Approve: %v", err)
117 }
118
119 // Advance last_polled_at so the slow_down gate doesn't fire on the
120 // second Exchange below; otherwise we'd race the 5-second window.
121 if _, err := pool.Exec(context.Background(),
122 "UPDATE device_authorizations SET last_polled_at = now() - interval '10 seconds' WHERE id = $1",
123 row.ID); err != nil {
124 t.Fatalf("rewind last_polled_at: %v", err)
125 }
126
127 res, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test")
128 if err != nil {
129 t.Fatalf("Exchange after approve: %v", err)
130 }
131 if !strings.HasPrefix(res.AccessToken, "shithub_pat_") {
132 t.Errorf("access_token prefix: got %q", res.AccessToken)
133 }
134 if res.TokenType != "bearer" {
135 t.Errorf("token_type: got %q", res.TokenType)
136 }
137 if got, want := res.Scopes, []string{"user:read"}; !equalStrings(got, want) {
138 t.Errorf("scopes: got %v, want %v", got, want)
139 }
140
141 // One-shot lockout: a second Exchange must NOT re-issue.
142 if _, err := pool.Exec(context.Background(),
143 "UPDATE device_authorizations SET last_polled_at = now() - interval '10 seconds' WHERE id = $1",
144 row.ID); err != nil {
145 t.Fatalf("rewind last_polled_at: %v", err)
146 }
147 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrInvalidGrant) {
148 t.Fatalf("one-shot: got %v, want ErrInvalidGrant", err)
149 }
150 }
151
152 func TestExchange_DeniedReturnsAccessDenied(t *testing.T) {
153 pool := dbtest.NewTestDB(t)
154 deps := devicecode.Deps{Pool: pool}
155 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
156 if err != nil {
157 t.Fatalf("Create: %v", err)
158 }
159 row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
160 if err := devicecode.Deny(context.Background(), deps, row.ID); err != nil {
161 t.Fatalf("Deny: %v", err)
162 }
163 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAccessDenied) {
164 t.Fatalf("denied: got %v, want ErrAccessDenied", err)
165 }
166 }
167
168 func TestExchange_ExpiredReturnsExpiredToken(t *testing.T) {
169 pool := dbtest.NewTestDB(t)
170 deps := devicecode.Deps{Pool: pool}
171 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
172 if err != nil {
173 t.Fatalf("Create: %v", err)
174 }
175 // Backdate expires_at past now().
176 if _, err := pool.Exec(context.Background(),
177 "UPDATE device_authorizations SET expires_at = now() - interval '1 minute' WHERE user_code = $1",
178 auth.UserCode); err != nil {
179 t.Fatalf("backdate expires_at: %v", err)
180 }
181 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrExpiredToken) {
182 t.Fatalf("expired: got %v, want ErrExpiredToken", err)
183 }
184 }
185
186 func TestExchange_SlowDownAfterFastPoll(t *testing.T) {
187 pool := dbtest.NewTestDB(t)
188 deps := devicecode.Deps{Pool: pool}
189 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
190 if err != nil {
191 t.Fatalf("Create: %v", err)
192 }
193 // First poll lands; expect authorization_pending and a stamped last_polled_at.
194 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAuthorizationPending) {
195 t.Fatalf("first poll: got %v, want ErrAuthorizationPending", err)
196 }
197 // Immediate re-poll inside the interval → slow_down.
198 if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrSlowDown) {
199 t.Fatalf("second poll: got %v, want ErrSlowDown", err)
200 }
201 }
202
203 func TestExchange_WrongClientIDReturnsUnauthorizedClient(t *testing.T) {
204 pool := dbtest.NewTestDB(t)
205 deps := devicecode.Deps{Pool: pool}
206 cfg := defaultsForTest()
207 cfg.ClientIDs = []string{"shithub-cli", "other-cli"}
208 auth, err := devicecode.Create(context.Background(), deps, cfg, "shithub-cli", "")
209 if err != nil {
210 t.Fatalf("Create: %v", err)
211 }
212 if _, err := devicecode.Exchange(context.Background(), deps, "other-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrUnauthorizedClient) {
213 t.Fatalf("got %v, want ErrUnauthorizedClient", err)
214 }
215 }
216
217 func TestApprove_RejectsAlreadyTerminal(t *testing.T) {
218 pool := dbtest.NewTestDB(t)
219 deps := devicecode.Deps{Pool: pool}
220 userID := seedUser(t, pool, "alice")
221 auth, _ := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
222 row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
223 if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
224 t.Fatalf("first approve: %v", err)
225 }
226 // Manually mark approved_at via Approve being called against an
227 // already-approved row.
228 if err := devicecode.Approve(context.Background(), deps, row.ID, userID); !errors.Is(err, devicecode.ErrAlreadyTerminal) {
229 t.Fatalf("second approve: got %v, want ErrAlreadyTerminal", err)
230 }
231 }
232
233 // equalStrings is a tiny slice-equality helper to keep imports light.
234 func equalStrings(a, b []string) bool {
235 if len(a) != len(b) {
236 return false
237 }
238 for i := range a {
239 if a[i] != b[i] {
240 return false
241 }
242 }
243 return true
244 }
245
246 // Sanity: the orchestrator must persist a row that the lookup query
247 // can find.
248 func TestCreate_PersistsRow(t *testing.T) {
249 pool := dbtest.NewTestDB(t)
250 deps := devicecode.Deps{Pool: pool}
251 auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
252 if err != nil {
253 t.Fatalf("Create: %v", err)
254 }
255 row, err := usersdb.New().GetDeviceAuthorizationByUserCode(context.Background(), pool, auth.UserCode)
256 if err != nil {
257 t.Fatalf("lookup: %v", err)
258 }
259 if row.ApprovedAt.Valid || row.DeniedAt.Valid {
260 t.Errorf("fresh row should be pending; got approved=%v denied=%v",
261 row.ApprovedAt, row.DeniedAt)
262 }
263 if !row.ExpiresAt.Valid || row.ExpiresAt.Time.Before(time.Now()) {
264 t.Errorf("expires_at not in the future: %+v", row.ExpiresAt)
265 }
266 if row.UserID.Valid {
267 t.Errorf("user_id should be unset before approval; got %v", row.UserID)
268 }
269 }
270
271 // Defensive: a pgx.Int8 zero-value should marshal cleanly when we
272 // Approve a row that uses it for IssuedTokenID. This is a sanity check
273 // that the Approve query accepts a "null" issued_token_id at SQL level
274 // (we set it later in Exchange).
275 func TestApprove_SetsUserIDOnly(t *testing.T) {
276 pool := dbtest.NewTestDB(t)
277 deps := devicecode.Deps{Pool: pool}
278 userID := seedUser(t, pool, "alice")
279 auth, _ := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
280 row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
281 if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
282 t.Fatalf("Approve: %v", err)
283 }
284 got, err := usersdb.New().GetDeviceAuthorizationByUserCode(context.Background(), pool, auth.UserCode)
285 if err != nil {
286 t.Fatalf("lookup: %v", err)
287 }
288 if !got.ApprovedAt.Valid {
289 t.Errorf("approved_at not set")
290 }
291 if got.UserID != (pgtype.Int8{Int64: userID, Valid: true}) {
292 t.Errorf("user_id: got %+v, want %d", got.UserID, userID)
293 }
294 if got.IssuedTokenID.Valid {
295 t.Errorf("issued_token_id should be set only at Exchange time; got %+v", got.IssuedTokenID)
296 }
297 }
298