tenseleyflow/shithub / 323f392

Browse files

auth/devicecode: RFC 8628 orchestrators (Create/Approve/Deny/Exchange)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
323f3922985d000443615256f89ce05fb1488dda
Parents
3baee64
Tree
0345610

1 changed file

StatusFile+-
A internal/auth/devicecode/devicecode.go 411 0
internal/auth/devicecode/devicecode.goadded
@@ -0,0 +1,411 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package devicecode owns the RFC 8628 (OAuth 2.0 Device Authorization
4
+// Grant) state machine. The package is intentionally HTTP-shape-free:
5
+// it exposes orchestrators (Create / Approve / Deny / Exchange) that
6
+// HTTP handlers wrap in their thin request-decode / response-encode
7
+// layers.
8
+//
9
+// Design note on token disclosure: the raw PAT is minted at Exchange
10
+// time, not at Approve. This keeps the raw secret off the
11
+// device_authorizations row entirely — Approve only records consent
12
+// (user_id + approved_at + scopes are already there), and the first
13
+// successful Exchange materialises the PAT, stamps its id on
14
+// issued_token_id, and returns the raw value to the polling client.
15
+// Subsequent Exchange calls see issued_token_id is set and return
16
+// invalid_grant so the disclosure is exactly one-shot.
17
+package devicecode
18
+
19
+import (
20
+	"context"
21
+	"crypto/rand"
22
+	"crypto/sha256"
23
+	"errors"
24
+	"fmt"
25
+	"strings"
26
+	"time"
27
+
28
+	"github.com/jackc/pgx/v5/pgtype"
29
+	"github.com/jackc/pgx/v5/pgxpool"
30
+
31
+	"github.com/tenseleyFlow/shithub/internal/auth/pat"
32
+	usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
33
+)
34
+
35
+// Errors callers may need to distinguish. The HTTP layer maps these to
36
+// the RFC 8628 §3.5 error codes.
37
+var (
38
+	ErrUnauthorizedClient   = errors.New("devicecode: unauthorized client")
39
+	ErrInvalidScope         = errors.New("devicecode: invalid scope")
40
+	ErrInvalidGrant         = errors.New("devicecode: invalid grant")
41
+	ErrAuthorizationPending = errors.New("devicecode: authorization pending")
42
+	ErrSlowDown             = errors.New("devicecode: slow down")
43
+	ErrAccessDenied         = errors.New("devicecode: access denied")
44
+	ErrExpiredToken         = errors.New("devicecode: expired token")
45
+	ErrAlreadyTerminal      = errors.New("devicecode: already terminal")
46
+)
47
+
48
+// Config tunes the device-code grant. Zero values get RFC-shaped
49
+// defaults via effective() so handler wiring can pass through bare.
50
+type Config struct {
51
+	// ClientIDs is the allowlist enforced on Create. An empty list
52
+	// denies every request (deny-by-default).
53
+	ClientIDs []string
54
+	// DefaultScopes is applied when the request omits scope=. Mirrors
55
+	// gh's behavior of granting a minimal read-only set by default.
56
+	DefaultScopes []string
57
+	// ExpiresIn is the grant lifetime. Clamped to 30 minutes max.
58
+	ExpiresIn time.Duration
59
+	// PollInterval is the advertised minimum polling cadence; the
60
+	// Exchange path enforces slow_down against this.
61
+	PollInterval time.Duration
62
+}
63
+
64
+// Defaults returns the canonical defaults: 15-minute grants, 5-second
65
+// poll interval, user:read scope, shithub-cli allowlist.
66
+func Defaults() Config {
67
+	return Config{
68
+		ClientIDs:     []string{"shithub-cli"},
69
+		DefaultScopes: []string{"user:read"},
70
+		ExpiresIn:     15 * time.Minute,
71
+		PollInterval:  5 * time.Second,
72
+	}
73
+}
74
+
75
+func (c Config) effective() Config {
76
+	out := c
77
+	if out.ExpiresIn <= 0 || out.ExpiresIn > 30*time.Minute {
78
+		out.ExpiresIn = 15 * time.Minute
79
+	}
80
+	if out.PollInterval <= 0 || out.PollInterval > time.Minute {
81
+		out.PollInterval = 5 * time.Second
82
+	}
83
+	if len(out.DefaultScopes) == 0 {
84
+		out.DefaultScopes = []string{"user:read"}
85
+	}
86
+	return out
87
+}
88
+
89
+// Deps wires the package to the database.
90
+type Deps struct {
91
+	Pool *pgxpool.Pool
92
+}
93
+
94
+// Authorization is the package-facing projection of an in-flight
95
+// device-code grant returned from Create.
96
+type Authorization struct {
97
+	DeviceCode   string // raw, returned to the client exactly once
98
+	UserCode     string // "ABCD-EFGH" shown to the user
99
+	ExpiresIn    time.Duration
100
+	PollInterval time.Duration
101
+	ClientID     string
102
+	Scopes       []string
103
+}
104
+
105
+// ExchangeResult carries the outcome of a successful Exchange.
106
+type ExchangeResult struct {
107
+	AccessToken string // raw PAT — returned to the client exactly once
108
+	TokenType   string // always "bearer"
109
+	Scopes      []string
110
+}
111
+
112
+// Create issues a fresh device authorization. The raw device_code is
113
+// stored only as its sha256 hash; the caller must propagate the
114
+// returned raw value to the client without persisting it server-side.
115
+func Create(ctx context.Context, deps Deps, cfg Config, clientID, scopeInput string) (Authorization, error) {
116
+	c := cfg.effective()
117
+	if !allowedClient(c.ClientIDs, clientID) {
118
+		return Authorization{}, ErrUnauthorizedClient
119
+	}
120
+	scopes, err := parseScopes(scopeInput, c.DefaultScopes)
121
+	if err != nil {
122
+		return Authorization{}, err
123
+	}
124
+
125
+	deviceCodeRaw, deviceCodeHash, err := newDeviceCode()
126
+	if err != nil {
127
+		return Authorization{}, err
128
+	}
129
+	userCode, err := newUserCode()
130
+	if err != nil {
131
+		return Authorization{}, err
132
+	}
133
+
134
+	if _, err := usersdb.New().InsertDeviceAuthorization(ctx, deps.Pool, usersdb.InsertDeviceAuthorizationParams{
135
+		DeviceCodeHash:  deviceCodeHash,
136
+		UserCode:        userCode,
137
+		ClientID:        clientID,
138
+		Scopes:          scopes,
139
+		IntervalSeconds: int32(c.PollInterval / time.Second),
140
+		ExpiresAt:       pgtype.Timestamptz{Time: time.Now().Add(c.ExpiresIn), Valid: true},
141
+	}); err != nil {
142
+		return Authorization{}, fmt.Errorf("devicecode: insert: %w", err)
143
+	}
144
+
145
+	return Authorization{
146
+		DeviceCode:   deviceCodeRaw,
147
+		UserCode:     userCode,
148
+		ExpiresIn:    c.ExpiresIn,
149
+		PollInterval: c.PollInterval,
150
+		ClientID:     clientID,
151
+		Scopes:       scopes,
152
+	}, nil
153
+}
154
+
155
+// LookupByUserCode resolves the user-facing code to the underlying row
156
+// for the verification page. Returns ErrInvalidGrant on miss so the
157
+// HTML handler can render a uniform "code not recognised" message.
158
+func LookupByUserCode(ctx context.Context, deps Deps, userCode string) (usersdb.DeviceAuthorization, error) {
159
+	row, err := usersdb.New().GetDeviceAuthorizationByUserCode(ctx, deps.Pool, normaliseUserCode(userCode))
160
+	if err != nil {
161
+		return usersdb.DeviceAuthorization{}, ErrInvalidGrant
162
+	}
163
+	return row, nil
164
+}
165
+
166
+// Approve records the user's consent. The PAT is NOT minted here — see
167
+// the package doc comment for the reasoning. Approve returns the row's
168
+// id so the caller can hand it back to the HTML success page without
169
+// re-resolving by user_code.
170
+func Approve(ctx context.Context, deps Deps, rowID, userID int64) error {
171
+	full, err := loadByID(ctx, deps.Pool, rowID)
172
+	if err != nil {
173
+		return ErrInvalidGrant
174
+	}
175
+	if full.ApprovedAt.Valid || full.DeniedAt.Valid {
176
+		return ErrAlreadyTerminal
177
+	}
178
+	if time.Now().After(full.ExpiresAt.Time) {
179
+		return ErrExpiredToken
180
+	}
181
+	return usersdb.New().ApproveDeviceAuthorization(ctx, deps.Pool, usersdb.ApproveDeviceAuthorizationParams{
182
+		ID:            full.ID,
183
+		UserID:        pgtype.Int8{Int64: userID, Valid: true},
184
+		IssuedTokenID: pgtype.Int8{}, // populated by Exchange
185
+	})
186
+}
187
+
188
+// Deny terminates an in-flight authorization without minting a token.
189
+// Future Exchange polls return ErrAccessDenied.
190
+func Deny(ctx context.Context, deps Deps, rowID int64) error {
191
+	full, err := loadByID(ctx, deps.Pool, rowID)
192
+	if err != nil {
193
+		return ErrInvalidGrant
194
+	}
195
+	if full.ApprovedAt.Valid || full.DeniedAt.Valid {
196
+		return ErrAlreadyTerminal
197
+	}
198
+	return usersdb.New().DenyDeviceAuthorization(ctx, deps.Pool, full.ID)
199
+}
200
+
201
+// Exchange is the CLI-facing poll. Returns the minted PAT exactly
202
+// once: on the first successful exchange after approval. Subsequent
203
+// polls (after issued_token_id is set) return ErrInvalidGrant so the
204
+// raw token is disclosed at most once.
205
+//
206
+// The slow_down enforcement uses last_polled_at + interval_seconds.
207
+// The check fires BEFORE the approval check so a fast-polling client
208
+// gets a clear back-off signal instead of an accidental "pending".
209
+func Exchange(ctx context.Context, deps Deps, clientID, rawDeviceCode, tokenName string) (ExchangeResult, error) {
210
+	hash, hashErr := hashDeviceCode(rawDeviceCode)
211
+	if hashErr != nil {
212
+		return ExchangeResult{}, ErrInvalidGrant
213
+	}
214
+	q := usersdb.New()
215
+	row, dbErr := q.GetDeviceAuthorizationByCodeHash(ctx, deps.Pool, hash)
216
+	if dbErr != nil {
217
+		return ExchangeResult{}, ErrInvalidGrant
218
+	}
219
+	if row.ClientID != clientID {
220
+		return ExchangeResult{}, ErrUnauthorizedClient
221
+	}
222
+	if time.Now().After(row.ExpiresAt.Time) {
223
+		return ExchangeResult{}, ErrExpiredToken
224
+	}
225
+	if row.DeniedAt.Valid {
226
+		return ExchangeResult{}, ErrAccessDenied
227
+	}
228
+
229
+	if row.LastPolledAt.Valid {
230
+		minNext := row.LastPolledAt.Time.Add(time.Duration(row.IntervalSeconds) * time.Second)
231
+		if time.Now().Before(minNext) {
232
+			_ = q.TouchDeviceAuthorizationPoll(ctx, deps.Pool, row.ID)
233
+			return ExchangeResult{}, ErrSlowDown
234
+		}
235
+	}
236
+	_ = q.TouchDeviceAuthorizationPoll(ctx, deps.Pool, row.ID)
237
+
238
+	if !row.ApprovedAt.Valid {
239
+		return ExchangeResult{}, ErrAuthorizationPending
240
+	}
241
+	if row.IssuedTokenID.Valid {
242
+		// One-shot disclosure already happened. The CLI either lost
243
+		// the previous response or someone is replaying the grant.
244
+		return ExchangeResult{}, ErrInvalidGrant
245
+	}
246
+	if !row.UserID.Valid {
247
+		// Approved row without a user_id is a corrupted state — the
248
+		// approval path always sets both atomically. Surface as
249
+		// invalid_grant.
250
+		return ExchangeResult{}, ErrInvalidGrant
251
+	}
252
+
253
+	raw, hashBytes, prefix, err := pat.Mint()
254
+	if err != nil {
255
+		return ExchangeResult{}, fmt.Errorf("devicecode: mint pat: %w", err)
256
+	}
257
+	tok, err := q.InsertUserToken(ctx, deps.Pool, usersdb.InsertUserTokenParams{
258
+		UserID:      row.UserID.Int64,
259
+		Name:        tokenName,
260
+		TokenHash:   hashBytes,
261
+		TokenPrefix: prefix,
262
+		Scopes:      row.Scopes,
263
+		ExpiresAt:   pgtype.Timestamptz{},
264
+	})
265
+	if err != nil {
266
+		return ExchangeResult{}, fmt.Errorf("devicecode: insert token: %w", err)
267
+	}
268
+
269
+	// Stamp the issued token id back onto the row so the next
270
+	// Exchange poll sees the one-shot lockout. Re-running Approve at
271
+	// the SQL layer would clear approved_at; we use a dedicated raw
272
+	// UPDATE here to keep the semantics tight.
273
+	if _, err := deps.Pool.Exec(ctx, `
274
+		UPDATE device_authorizations SET issued_token_id = $2 WHERE id = $1
275
+	`, row.ID, tok.ID); err != nil {
276
+		return ExchangeResult{}, fmt.Errorf("devicecode: stamp token: %w", err)
277
+	}
278
+
279
+	return ExchangeResult{
280
+		AccessToken: raw,
281
+		TokenType:   "bearer",
282
+		Scopes:      row.Scopes,
283
+	}, nil
284
+}
285
+
286
+func allowedClient(allow []string, clientID string) bool {
287
+	for _, v := range allow {
288
+		if v == clientID {
289
+			return true
290
+		}
291
+	}
292
+	return false
293
+}
294
+
295
+// parseScopes accepts space- or comma-separated scopes and returns the
296
+// normalised, deduped, validated slice. Unknown scopes → ErrInvalidScope.
297
+func parseScopes(input string, defaults []string) ([]string, error) {
298
+	input = strings.TrimSpace(input)
299
+	if input == "" {
300
+		out := make([]string, len(defaults))
301
+		copy(out, defaults)
302
+		return out, nil
303
+	}
304
+	raw := strings.FieldsFunc(input, func(r rune) bool { return r == ',' || r == ' ' })
305
+	seen := make(map[string]struct{}, len(raw))
306
+	out := make([]string, 0, len(raw))
307
+	for _, s := range raw {
308
+		s = strings.TrimSpace(s)
309
+		if s == "" {
310
+			continue
311
+		}
312
+		if !pat.ValidScope(s) {
313
+			return nil, ErrInvalidScope
314
+		}
315
+		if _, dup := seen[s]; dup {
316
+			continue
317
+		}
318
+		seen[s] = struct{}{}
319
+		out = append(out, s)
320
+	}
321
+	if len(out) == 0 {
322
+		out = append(out, defaults...)
323
+	}
324
+	return out, nil
325
+}
326
+
327
+func newDeviceCode() (raw string, hash []byte, err error) {
328
+	buf := make([]byte, 32)
329
+	if _, err := rand.Read(buf); err != nil {
330
+		return "", nil, fmt.Errorf("devicecode: rand: %w", err)
331
+	}
332
+	raw = hexEncode(buf)
333
+	sum := sha256.Sum256([]byte(raw))
334
+	return raw, sum[:], nil
335
+}
336
+
337
+func hashDeviceCode(raw string) ([]byte, error) {
338
+	raw = strings.TrimSpace(raw)
339
+	if raw == "" {
340
+		return nil, ErrInvalidGrant
341
+	}
342
+	sum := sha256.Sum256([]byte(raw))
343
+	return sum[:], nil
344
+}
345
+
346
+// newUserCode produces an 8-character "ABCD-EFGH" identifier from a
347
+// 32-symbol alphabet that excludes 0/O/1/I to avoid typing ambiguity.
348
+func newUserCode() (string, error) {
349
+	const alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
350
+	buf := make([]byte, 8)
351
+	if _, err := rand.Read(buf); err != nil {
352
+		return "", err
353
+	}
354
+	out := make([]byte, 0, 9)
355
+	for i, b := range buf {
356
+		out = append(out, alphabet[int(b)%len(alphabet)])
357
+		if i == 3 {
358
+			out = append(out, '-')
359
+		}
360
+	}
361
+	return string(out), nil
362
+}
363
+
364
+func normaliseUserCode(s string) string {
365
+	s = strings.ToUpper(strings.TrimSpace(s))
366
+	s = strings.ReplaceAll(s, " ", "")
367
+	if len(s) == 8 && !strings.Contains(s, "-") {
368
+		s = s[:4] + "-" + s[4:]
369
+	}
370
+	return s
371
+}
372
+
373
+func hexEncode(b []byte) string {
374
+	const hexd = "0123456789abcdef"
375
+	out := make([]byte, len(b)*2)
376
+	for i, v := range b {
377
+		out[i*2] = hexd[v>>4]
378
+		out[i*2+1] = hexd[v&0x0f]
379
+	}
380
+	return string(out)
381
+}
382
+
383
+// loadByID resolves a row by its bigserial id. sqlc didn't generate a
384
+// dedicated GetByID because every other consumer comes in via
385
+// device_code_hash or user_code; the Approve / Deny path needs id
386
+// because it has it from the prior LookupByUserCode call.
387
+func loadByID(ctx context.Context, pool *pgxpool.Pool, id int64) (usersdb.DeviceAuthorization, error) {
388
+	rows, err := pool.Query(ctx, `
389
+		SELECT id, device_code_hash, user_code, client_id, scopes, user_id,
390
+		       approved_at, denied_at, issued_token_id, interval_seconds,
391
+		       expires_at, last_polled_at, created_at
392
+		FROM device_authorizations
393
+		WHERE id = $1
394
+	`, id)
395
+	if err != nil {
396
+		return usersdb.DeviceAuthorization{}, err
397
+	}
398
+	defer rows.Close()
399
+	if !rows.Next() {
400
+		return usersdb.DeviceAuthorization{}, ErrInvalidGrant
401
+	}
402
+	var a usersdb.DeviceAuthorization
403
+	if err := rows.Scan(
404
+		&a.ID, &a.DeviceCodeHash, &a.UserCode, &a.ClientID, &a.Scopes, &a.UserID,
405
+		&a.ApprovedAt, &a.DeniedAt, &a.IssuedTokenID, &a.IntervalSeconds,
406
+		&a.ExpiresAt, &a.LastPolledAt, &a.CreatedAt,
407
+	); err != nil {
408
+		return usersdb.DeviceAuthorization{}, err
409
+	}
410
+	return a, nil
411
+}