tenseleyflow/shithub / 2857fe8

Browse files

auth/devicecode: orchestrator tests (11 scenarios)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
2857fe8278a80bc115a66cb367a93a9581be99c6
Parents
9c6f6f4
Tree
508204a

1 changed file

StatusFile+-
A internal/auth/devicecode/devicecode_test.go 297 0
internal/auth/devicecode/devicecode_test.goadded
@@ -0,0 +1,297 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package devicecode_test
4
+
5
+import (
6
+	"context"
7
+	"errors"
8
+	"strings"
9
+	"testing"
10
+	"time"
11
+
12
+	"github.com/jackc/pgx/v5/pgtype"
13
+	"github.com/jackc/pgx/v5/pgxpool"
14
+
15
+	"github.com/tenseleyFlow/shithub/internal/auth/devicecode"
16
+	"github.com/tenseleyFlow/shithub/internal/testing/dbtest"
17
+	usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
18
+)
19
+
20
+// fixedArgon2Hash is a pre-computed argon2id hash of an arbitrary password.
21
+// Using a constant avoids running argon2 in every test setup.
22
+const fixedArgon2Hash = "$argon2id$v=19$m=1024,t=1,p=1$YWFhYWFhYWFhYWFhYWFhYQ$" +
23
+	"DvBOTSnFhCBe+Pfx/W7Sk3hG3JCm2Wj0RBgCu+CPDtY"
24
+
25
+func seedUser(t *testing.T, pool *pgxpool.Pool, username string) int64 {
26
+	t.Helper()
27
+	q := usersdb.New()
28
+	u, err := q.CreateUser(context.Background(), pool, usersdb.CreateUserParams{
29
+		Username:     username,
30
+		DisplayName:  strings.ToUpper(username[:1]) + username[1:],
31
+		PasswordHash: fixedArgon2Hash,
32
+	})
33
+	if err != nil {
34
+		t.Fatalf("CreateUser: %v", err)
35
+	}
36
+	em, err := q.CreateUserEmail(context.Background(), pool, usersdb.CreateUserEmailParams{
37
+		UserID:    u.ID,
38
+		Email:     username + "@example.test",
39
+		IsPrimary: true,
40
+	})
41
+	if err != nil {
42
+		t.Fatalf("CreateUserEmail: %v", err)
43
+	}
44
+	if err := q.MarkUserEmailVerified(context.Background(), pool, em.ID); err != nil {
45
+		t.Fatalf("MarkUserEmailVerified: %v", err)
46
+	}
47
+	return u.ID
48
+}
49
+
50
+func defaultsForTest() devicecode.Config {
51
+	return devicecode.Config{
52
+		ClientIDs:     []string{"shithub-cli"},
53
+		DefaultScopes: []string{"user:read"},
54
+		ExpiresIn:     15 * time.Minute,
55
+		PollInterval:  5 * time.Second,
56
+	}
57
+}
58
+
59
+func TestCreate_HappyPath(t *testing.T) {
60
+	pool := dbtest.NewTestDB(t)
61
+	deps := devicecode.Deps{Pool: pool}
62
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "user:read,repo:read")
63
+	if err != nil {
64
+		t.Fatalf("Create: %v", err)
65
+	}
66
+	if auth.DeviceCode == "" || auth.UserCode == "" {
67
+		t.Fatalf("Create returned empty codes: %+v", auth)
68
+	}
69
+	if !strings.Contains(auth.UserCode, "-") || len(auth.UserCode) != 9 {
70
+		t.Errorf("user_code shape: got %q", auth.UserCode)
71
+	}
72
+	if auth.ExpiresIn != 15*time.Minute {
73
+		t.Errorf("expires_in: got %s", auth.ExpiresIn)
74
+	}
75
+	if got, want := auth.Scopes, []string{"user:read", "repo:read"}; !equalStrings(got, want) {
76
+		t.Errorf("scopes: got %v, want %v", got, want)
77
+	}
78
+}
79
+
80
+func TestCreate_RejectsUnknownClient(t *testing.T) {
81
+	pool := dbtest.NewTestDB(t)
82
+	_, err := devicecode.Create(context.Background(), devicecode.Deps{Pool: pool}, defaultsForTest(), "evil-cli", "")
83
+	if !errors.Is(err, devicecode.ErrUnauthorizedClient) {
84
+		t.Fatalf("got err %v, want ErrUnauthorizedClient", err)
85
+	}
86
+}
87
+
88
+func TestCreate_RejectsUnknownScope(t *testing.T) {
89
+	pool := dbtest.NewTestDB(t)
90
+	_, err := devicecode.Create(context.Background(), devicecode.Deps{Pool: pool}, defaultsForTest(), "shithub-cli", "user:read,bogus:scope")
91
+	if !errors.Is(err, devicecode.ErrInvalidScope) {
92
+		t.Fatalf("got err %v, want ErrInvalidScope", err)
93
+	}
94
+}
95
+
96
+func TestExchange_PendingThenApprovedThenOneShot(t *testing.T) {
97
+	pool := dbtest.NewTestDB(t)
98
+	deps := devicecode.Deps{Pool: pool}
99
+	userID := seedUser(t, pool, "alice")
100
+
101
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "user:read")
102
+	if err != nil {
103
+		t.Fatalf("Create: %v", err)
104
+	}
105
+
106
+	// Pending: Exchange returns authorization_pending.
107
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAuthorizationPending) {
108
+		t.Fatalf("pending: got %v, want ErrAuthorizationPending", err)
109
+	}
110
+
111
+	row, err := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
112
+	if err != nil {
113
+		t.Fatalf("LookupByUserCode: %v", err)
114
+	}
115
+	if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
116
+		t.Fatalf("Approve: %v", err)
117
+	}
118
+
119
+	// Advance last_polled_at so the slow_down gate doesn't fire on the
120
+	// second Exchange below; otherwise we'd race the 5-second window.
121
+	if _, err := pool.Exec(context.Background(),
122
+		"UPDATE device_authorizations SET last_polled_at = now() - interval '10 seconds' WHERE id = $1",
123
+		row.ID); err != nil {
124
+		t.Fatalf("rewind last_polled_at: %v", err)
125
+	}
126
+
127
+	res, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test")
128
+	if err != nil {
129
+		t.Fatalf("Exchange after approve: %v", err)
130
+	}
131
+	if !strings.HasPrefix(res.AccessToken, "shithub_pat_") {
132
+		t.Errorf("access_token prefix: got %q", res.AccessToken)
133
+	}
134
+	if res.TokenType != "bearer" {
135
+		t.Errorf("token_type: got %q", res.TokenType)
136
+	}
137
+	if got, want := res.Scopes, []string{"user:read"}; !equalStrings(got, want) {
138
+		t.Errorf("scopes: got %v, want %v", got, want)
139
+	}
140
+
141
+	// One-shot lockout: a second Exchange must NOT re-issue.
142
+	if _, err := pool.Exec(context.Background(),
143
+		"UPDATE device_authorizations SET last_polled_at = now() - interval '10 seconds' WHERE id = $1",
144
+		row.ID); err != nil {
145
+		t.Fatalf("rewind last_polled_at: %v", err)
146
+	}
147
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrInvalidGrant) {
148
+		t.Fatalf("one-shot: got %v, want ErrInvalidGrant", err)
149
+	}
150
+}
151
+
152
+func TestExchange_DeniedReturnsAccessDenied(t *testing.T) {
153
+	pool := dbtest.NewTestDB(t)
154
+	deps := devicecode.Deps{Pool: pool}
155
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
156
+	if err != nil {
157
+		t.Fatalf("Create: %v", err)
158
+	}
159
+	row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
160
+	if err := devicecode.Deny(context.Background(), deps, row.ID); err != nil {
161
+		t.Fatalf("Deny: %v", err)
162
+	}
163
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAccessDenied) {
164
+		t.Fatalf("denied: got %v, want ErrAccessDenied", err)
165
+	}
166
+}
167
+
168
+func TestExchange_ExpiredReturnsExpiredToken(t *testing.T) {
169
+	pool := dbtest.NewTestDB(t)
170
+	deps := devicecode.Deps{Pool: pool}
171
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
172
+	if err != nil {
173
+		t.Fatalf("Create: %v", err)
174
+	}
175
+	// Backdate expires_at past now().
176
+	if _, err := pool.Exec(context.Background(),
177
+		"UPDATE device_authorizations SET expires_at = now() - interval '1 minute' WHERE user_code = $1",
178
+		auth.UserCode); err != nil {
179
+		t.Fatalf("backdate expires_at: %v", err)
180
+	}
181
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrExpiredToken) {
182
+		t.Fatalf("expired: got %v, want ErrExpiredToken", err)
183
+	}
184
+}
185
+
186
+func TestExchange_SlowDownAfterFastPoll(t *testing.T) {
187
+	pool := dbtest.NewTestDB(t)
188
+	deps := devicecode.Deps{Pool: pool}
189
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
190
+	if err != nil {
191
+		t.Fatalf("Create: %v", err)
192
+	}
193
+	// First poll lands; expect authorization_pending and a stamped last_polled_at.
194
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrAuthorizationPending) {
195
+		t.Fatalf("first poll: got %v, want ErrAuthorizationPending", err)
196
+	}
197
+	// Immediate re-poll inside the interval → slow_down.
198
+	if _, err := devicecode.Exchange(context.Background(), deps, "shithub-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrSlowDown) {
199
+		t.Fatalf("second poll: got %v, want ErrSlowDown", err)
200
+	}
201
+}
202
+
203
+func TestExchange_WrongClientIDReturnsUnauthorizedClient(t *testing.T) {
204
+	pool := dbtest.NewTestDB(t)
205
+	deps := devicecode.Deps{Pool: pool}
206
+	cfg := defaultsForTest()
207
+	cfg.ClientIDs = []string{"shithub-cli", "other-cli"}
208
+	auth, err := devicecode.Create(context.Background(), deps, cfg, "shithub-cli", "")
209
+	if err != nil {
210
+		t.Fatalf("Create: %v", err)
211
+	}
212
+	if _, err := devicecode.Exchange(context.Background(), deps, "other-cli", auth.DeviceCode, "test"); !errors.Is(err, devicecode.ErrUnauthorizedClient) {
213
+		t.Fatalf("got %v, want ErrUnauthorizedClient", err)
214
+	}
215
+}
216
+
217
+func TestApprove_RejectsAlreadyTerminal(t *testing.T) {
218
+	pool := dbtest.NewTestDB(t)
219
+	deps := devicecode.Deps{Pool: pool}
220
+	userID := seedUser(t, pool, "alice")
221
+	auth, _ := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
222
+	row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
223
+	if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
224
+		t.Fatalf("first approve: %v", err)
225
+	}
226
+	// Manually mark approved_at via Approve being called against an
227
+	// already-approved row.
228
+	if err := devicecode.Approve(context.Background(), deps, row.ID, userID); !errors.Is(err, devicecode.ErrAlreadyTerminal) {
229
+		t.Fatalf("second approve: got %v, want ErrAlreadyTerminal", err)
230
+	}
231
+}
232
+
233
+// equalStrings is a tiny slice-equality helper to keep imports light.
234
+func equalStrings(a, b []string) bool {
235
+	if len(a) != len(b) {
236
+		return false
237
+	}
238
+	for i := range a {
239
+		if a[i] != b[i] {
240
+			return false
241
+		}
242
+	}
243
+	return true
244
+}
245
+
246
+// Sanity: the orchestrator must persist a row that the lookup query
247
+// can find.
248
+func TestCreate_PersistsRow(t *testing.T) {
249
+	pool := dbtest.NewTestDB(t)
250
+	deps := devicecode.Deps{Pool: pool}
251
+	auth, err := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
252
+	if err != nil {
253
+		t.Fatalf("Create: %v", err)
254
+	}
255
+	row, err := usersdb.New().GetDeviceAuthorizationByUserCode(context.Background(), pool, auth.UserCode)
256
+	if err != nil {
257
+		t.Fatalf("lookup: %v", err)
258
+	}
259
+	if row.ApprovedAt.Valid || row.DeniedAt.Valid {
260
+		t.Errorf("fresh row should be pending; got approved=%v denied=%v",
261
+			row.ApprovedAt, row.DeniedAt)
262
+	}
263
+	if !row.ExpiresAt.Valid || row.ExpiresAt.Time.Before(time.Now()) {
264
+		t.Errorf("expires_at not in the future: %+v", row.ExpiresAt)
265
+	}
266
+	if row.UserID.Valid {
267
+		t.Errorf("user_id should be unset before approval; got %v", row.UserID)
268
+	}
269
+}
270
+
271
+// Defensive: a pgx.Int8 zero-value should marshal cleanly when we
272
+// Approve a row that uses it for IssuedTokenID. This is a sanity check
273
+// that the Approve query accepts a "null" issued_token_id at SQL level
274
+// (we set it later in Exchange).
275
+func TestApprove_SetsUserIDOnly(t *testing.T) {
276
+	pool := dbtest.NewTestDB(t)
277
+	deps := devicecode.Deps{Pool: pool}
278
+	userID := seedUser(t, pool, "alice")
279
+	auth, _ := devicecode.Create(context.Background(), deps, defaultsForTest(), "shithub-cli", "")
280
+	row, _ := devicecode.LookupByUserCode(context.Background(), deps, auth.UserCode)
281
+	if err := devicecode.Approve(context.Background(), deps, row.ID, userID); err != nil {
282
+		t.Fatalf("Approve: %v", err)
283
+	}
284
+	got, err := usersdb.New().GetDeviceAuthorizationByUserCode(context.Background(), pool, auth.UserCode)
285
+	if err != nil {
286
+		t.Fatalf("lookup: %v", err)
287
+	}
288
+	if !got.ApprovedAt.Valid {
289
+		t.Errorf("approved_at not set")
290
+	}
291
+	if got.UserID != (pgtype.Int8{Int64: userID, Valid: true}) {
292
+		t.Errorf("user_id: got %+v, want %d", got.UserID, userID)
293
+	}
294
+	if got.IssuedTokenID.Valid {
295
+		t.Errorf("issued_token_id should be set only at Exchange time; got %+v", got.IssuedTokenID)
296
+	}
297
+}