@@ -0,0 +1,411 @@ |
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | + |
| 3 | +// Package devicecode owns the RFC 8628 (OAuth 2.0 Device Authorization |
| 4 | +// Grant) state machine. The package is intentionally HTTP-shape-free: |
| 5 | +// it exposes orchestrators (Create / Approve / Deny / Exchange) that |
| 6 | +// HTTP handlers wrap in their thin request-decode / response-encode |
| 7 | +// layers. |
| 8 | +// |
| 9 | +// Design note on token disclosure: the raw PAT is minted at Exchange |
| 10 | +// time, not at Approve. This keeps the raw secret off the |
| 11 | +// device_authorizations row entirely — Approve only records consent |
| 12 | +// (user_id + approved_at + scopes are already there), and the first |
| 13 | +// successful Exchange materialises the PAT, stamps its id on |
| 14 | +// issued_token_id, and returns the raw value to the polling client. |
| 15 | +// Subsequent Exchange calls see issued_token_id is set and return |
| 16 | +// invalid_grant so the disclosure is exactly one-shot. |
| 17 | +package devicecode |
| 18 | + |
| 19 | +import ( |
| 20 | + "context" |
| 21 | + "crypto/rand" |
| 22 | + "crypto/sha256" |
| 23 | + "errors" |
| 24 | + "fmt" |
| 25 | + "strings" |
| 26 | + "time" |
| 27 | + |
| 28 | + "github.com/jackc/pgx/v5/pgtype" |
| 29 | + "github.com/jackc/pgx/v5/pgxpool" |
| 30 | + |
| 31 | + "github.com/tenseleyFlow/shithub/internal/auth/pat" |
| 32 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 33 | +) |
| 34 | + |
| 35 | +// Errors callers may need to distinguish. The HTTP layer maps these to |
| 36 | +// the RFC 8628 §3.5 error codes. |
| 37 | +var ( |
| 38 | + ErrUnauthorizedClient = errors.New("devicecode: unauthorized client") |
| 39 | + ErrInvalidScope = errors.New("devicecode: invalid scope") |
| 40 | + ErrInvalidGrant = errors.New("devicecode: invalid grant") |
| 41 | + ErrAuthorizationPending = errors.New("devicecode: authorization pending") |
| 42 | + ErrSlowDown = errors.New("devicecode: slow down") |
| 43 | + ErrAccessDenied = errors.New("devicecode: access denied") |
| 44 | + ErrExpiredToken = errors.New("devicecode: expired token") |
| 45 | + ErrAlreadyTerminal = errors.New("devicecode: already terminal") |
| 46 | +) |
| 47 | + |
| 48 | +// Config tunes the device-code grant. Zero values get RFC-shaped |
| 49 | +// defaults via effective() so handler wiring can pass through bare. |
| 50 | +type Config struct { |
| 51 | + // ClientIDs is the allowlist enforced on Create. An empty list |
| 52 | + // denies every request (deny-by-default). |
| 53 | + ClientIDs []string |
| 54 | + // DefaultScopes is applied when the request omits scope=. Mirrors |
| 55 | + // gh's behavior of granting a minimal read-only set by default. |
| 56 | + DefaultScopes []string |
| 57 | + // ExpiresIn is the grant lifetime. Clamped to 30 minutes max. |
| 58 | + ExpiresIn time.Duration |
| 59 | + // PollInterval is the advertised minimum polling cadence; the |
| 60 | + // Exchange path enforces slow_down against this. |
| 61 | + PollInterval time.Duration |
| 62 | +} |
| 63 | + |
| 64 | +// Defaults returns the canonical defaults: 15-minute grants, 5-second |
| 65 | +// poll interval, user:read scope, shithub-cli allowlist. |
| 66 | +func Defaults() Config { |
| 67 | + return Config{ |
| 68 | + ClientIDs: []string{"shithub-cli"}, |
| 69 | + DefaultScopes: []string{"user:read"}, |
| 70 | + ExpiresIn: 15 * time.Minute, |
| 71 | + PollInterval: 5 * time.Second, |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +func (c Config) effective() Config { |
| 76 | + out := c |
| 77 | + if out.ExpiresIn <= 0 || out.ExpiresIn > 30*time.Minute { |
| 78 | + out.ExpiresIn = 15 * time.Minute |
| 79 | + } |
| 80 | + if out.PollInterval <= 0 || out.PollInterval > time.Minute { |
| 81 | + out.PollInterval = 5 * time.Second |
| 82 | + } |
| 83 | + if len(out.DefaultScopes) == 0 { |
| 84 | + out.DefaultScopes = []string{"user:read"} |
| 85 | + } |
| 86 | + return out |
| 87 | +} |
| 88 | + |
| 89 | +// Deps wires the package to the database. |
| 90 | +type Deps struct { |
| 91 | + Pool *pgxpool.Pool |
| 92 | +} |
| 93 | + |
| 94 | +// Authorization is the package-facing projection of an in-flight |
| 95 | +// device-code grant returned from Create. |
| 96 | +type Authorization struct { |
| 97 | + DeviceCode string // raw, returned to the client exactly once |
| 98 | + UserCode string // "ABCD-EFGH" shown to the user |
| 99 | + ExpiresIn time.Duration |
| 100 | + PollInterval time.Duration |
| 101 | + ClientID string |
| 102 | + Scopes []string |
| 103 | +} |
| 104 | + |
| 105 | +// ExchangeResult carries the outcome of a successful Exchange. |
| 106 | +type ExchangeResult struct { |
| 107 | + AccessToken string // raw PAT — returned to the client exactly once |
| 108 | + TokenType string // always "bearer" |
| 109 | + Scopes []string |
| 110 | +} |
| 111 | + |
| 112 | +// Create issues a fresh device authorization. The raw device_code is |
| 113 | +// stored only as its sha256 hash; the caller must propagate the |
| 114 | +// returned raw value to the client without persisting it server-side. |
| 115 | +func Create(ctx context.Context, deps Deps, cfg Config, clientID, scopeInput string) (Authorization, error) { |
| 116 | + c := cfg.effective() |
| 117 | + if !allowedClient(c.ClientIDs, clientID) { |
| 118 | + return Authorization{}, ErrUnauthorizedClient |
| 119 | + } |
| 120 | + scopes, err := parseScopes(scopeInput, c.DefaultScopes) |
| 121 | + if err != nil { |
| 122 | + return Authorization{}, err |
| 123 | + } |
| 124 | + |
| 125 | + deviceCodeRaw, deviceCodeHash, err := newDeviceCode() |
| 126 | + if err != nil { |
| 127 | + return Authorization{}, err |
| 128 | + } |
| 129 | + userCode, err := newUserCode() |
| 130 | + if err != nil { |
| 131 | + return Authorization{}, err |
| 132 | + } |
| 133 | + |
| 134 | + if _, err := usersdb.New().InsertDeviceAuthorization(ctx, deps.Pool, usersdb.InsertDeviceAuthorizationParams{ |
| 135 | + DeviceCodeHash: deviceCodeHash, |
| 136 | + UserCode: userCode, |
| 137 | + ClientID: clientID, |
| 138 | + Scopes: scopes, |
| 139 | + IntervalSeconds: int32(c.PollInterval / time.Second), |
| 140 | + ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(c.ExpiresIn), Valid: true}, |
| 141 | + }); err != nil { |
| 142 | + return Authorization{}, fmt.Errorf("devicecode: insert: %w", err) |
| 143 | + } |
| 144 | + |
| 145 | + return Authorization{ |
| 146 | + DeviceCode: deviceCodeRaw, |
| 147 | + UserCode: userCode, |
| 148 | + ExpiresIn: c.ExpiresIn, |
| 149 | + PollInterval: c.PollInterval, |
| 150 | + ClientID: clientID, |
| 151 | + Scopes: scopes, |
| 152 | + }, nil |
| 153 | +} |
| 154 | + |
| 155 | +// LookupByUserCode resolves the user-facing code to the underlying row |
| 156 | +// for the verification page. Returns ErrInvalidGrant on miss so the |
| 157 | +// HTML handler can render a uniform "code not recognised" message. |
| 158 | +func LookupByUserCode(ctx context.Context, deps Deps, userCode string) (usersdb.DeviceAuthorization, error) { |
| 159 | + row, err := usersdb.New().GetDeviceAuthorizationByUserCode(ctx, deps.Pool, normaliseUserCode(userCode)) |
| 160 | + if err != nil { |
| 161 | + return usersdb.DeviceAuthorization{}, ErrInvalidGrant |
| 162 | + } |
| 163 | + return row, nil |
| 164 | +} |
| 165 | + |
| 166 | +// Approve records the user's consent. The PAT is NOT minted here — see |
| 167 | +// the package doc comment for the reasoning. Approve returns the row's |
| 168 | +// id so the caller can hand it back to the HTML success page without |
| 169 | +// re-resolving by user_code. |
| 170 | +func Approve(ctx context.Context, deps Deps, rowID, userID int64) error { |
| 171 | + full, err := loadByID(ctx, deps.Pool, rowID) |
| 172 | + if err != nil { |
| 173 | + return ErrInvalidGrant |
| 174 | + } |
| 175 | + if full.ApprovedAt.Valid || full.DeniedAt.Valid { |
| 176 | + return ErrAlreadyTerminal |
| 177 | + } |
| 178 | + if time.Now().After(full.ExpiresAt.Time) { |
| 179 | + return ErrExpiredToken |
| 180 | + } |
| 181 | + return usersdb.New().ApproveDeviceAuthorization(ctx, deps.Pool, usersdb.ApproveDeviceAuthorizationParams{ |
| 182 | + ID: full.ID, |
| 183 | + UserID: pgtype.Int8{Int64: userID, Valid: true}, |
| 184 | + IssuedTokenID: pgtype.Int8{}, // populated by Exchange |
| 185 | + }) |
| 186 | +} |
| 187 | + |
| 188 | +// Deny terminates an in-flight authorization without minting a token. |
| 189 | +// Future Exchange polls return ErrAccessDenied. |
| 190 | +func Deny(ctx context.Context, deps Deps, rowID int64) error { |
| 191 | + full, err := loadByID(ctx, deps.Pool, rowID) |
| 192 | + if err != nil { |
| 193 | + return ErrInvalidGrant |
| 194 | + } |
| 195 | + if full.ApprovedAt.Valid || full.DeniedAt.Valid { |
| 196 | + return ErrAlreadyTerminal |
| 197 | + } |
| 198 | + return usersdb.New().DenyDeviceAuthorization(ctx, deps.Pool, full.ID) |
| 199 | +} |
| 200 | + |
| 201 | +// Exchange is the CLI-facing poll. Returns the minted PAT exactly |
| 202 | +// once: on the first successful exchange after approval. Subsequent |
| 203 | +// polls (after issued_token_id is set) return ErrInvalidGrant so the |
| 204 | +// raw token is disclosed at most once. |
| 205 | +// |
| 206 | +// The slow_down enforcement uses last_polled_at + interval_seconds. |
| 207 | +// The check fires BEFORE the approval check so a fast-polling client |
| 208 | +// gets a clear back-off signal instead of an accidental "pending". |
| 209 | +func Exchange(ctx context.Context, deps Deps, clientID, rawDeviceCode, tokenName string) (ExchangeResult, error) { |
| 210 | + hash, hashErr := hashDeviceCode(rawDeviceCode) |
| 211 | + if hashErr != nil { |
| 212 | + return ExchangeResult{}, ErrInvalidGrant |
| 213 | + } |
| 214 | + q := usersdb.New() |
| 215 | + row, dbErr := q.GetDeviceAuthorizationByCodeHash(ctx, deps.Pool, hash) |
| 216 | + if dbErr != nil { |
| 217 | + return ExchangeResult{}, ErrInvalidGrant |
| 218 | + } |
| 219 | + if row.ClientID != clientID { |
| 220 | + return ExchangeResult{}, ErrUnauthorizedClient |
| 221 | + } |
| 222 | + if time.Now().After(row.ExpiresAt.Time) { |
| 223 | + return ExchangeResult{}, ErrExpiredToken |
| 224 | + } |
| 225 | + if row.DeniedAt.Valid { |
| 226 | + return ExchangeResult{}, ErrAccessDenied |
| 227 | + } |
| 228 | + |
| 229 | + if row.LastPolledAt.Valid { |
| 230 | + minNext := row.LastPolledAt.Time.Add(time.Duration(row.IntervalSeconds) * time.Second) |
| 231 | + if time.Now().Before(minNext) { |
| 232 | + _ = q.TouchDeviceAuthorizationPoll(ctx, deps.Pool, row.ID) |
| 233 | + return ExchangeResult{}, ErrSlowDown |
| 234 | + } |
| 235 | + } |
| 236 | + _ = q.TouchDeviceAuthorizationPoll(ctx, deps.Pool, row.ID) |
| 237 | + |
| 238 | + if !row.ApprovedAt.Valid { |
| 239 | + return ExchangeResult{}, ErrAuthorizationPending |
| 240 | + } |
| 241 | + if row.IssuedTokenID.Valid { |
| 242 | + // One-shot disclosure already happened. The CLI either lost |
| 243 | + // the previous response or someone is replaying the grant. |
| 244 | + return ExchangeResult{}, ErrInvalidGrant |
| 245 | + } |
| 246 | + if !row.UserID.Valid { |
| 247 | + // Approved row without a user_id is a corrupted state — the |
| 248 | + // approval path always sets both atomically. Surface as |
| 249 | + // invalid_grant. |
| 250 | + return ExchangeResult{}, ErrInvalidGrant |
| 251 | + } |
| 252 | + |
| 253 | + raw, hashBytes, prefix, err := pat.Mint() |
| 254 | + if err != nil { |
| 255 | + return ExchangeResult{}, fmt.Errorf("devicecode: mint pat: %w", err) |
| 256 | + } |
| 257 | + tok, err := q.InsertUserToken(ctx, deps.Pool, usersdb.InsertUserTokenParams{ |
| 258 | + UserID: row.UserID.Int64, |
| 259 | + Name: tokenName, |
| 260 | + TokenHash: hashBytes, |
| 261 | + TokenPrefix: prefix, |
| 262 | + Scopes: row.Scopes, |
| 263 | + ExpiresAt: pgtype.Timestamptz{}, |
| 264 | + }) |
| 265 | + if err != nil { |
| 266 | + return ExchangeResult{}, fmt.Errorf("devicecode: insert token: %w", err) |
| 267 | + } |
| 268 | + |
| 269 | + // Stamp the issued token id back onto the row so the next |
| 270 | + // Exchange poll sees the one-shot lockout. Re-running Approve at |
| 271 | + // the SQL layer would clear approved_at; we use a dedicated raw |
| 272 | + // UPDATE here to keep the semantics tight. |
| 273 | + if _, err := deps.Pool.Exec(ctx, ` |
| 274 | + UPDATE device_authorizations SET issued_token_id = $2 WHERE id = $1 |
| 275 | + `, row.ID, tok.ID); err != nil { |
| 276 | + return ExchangeResult{}, fmt.Errorf("devicecode: stamp token: %w", err) |
| 277 | + } |
| 278 | + |
| 279 | + return ExchangeResult{ |
| 280 | + AccessToken: raw, |
| 281 | + TokenType: "bearer", |
| 282 | + Scopes: row.Scopes, |
| 283 | + }, nil |
| 284 | +} |
| 285 | + |
| 286 | +func allowedClient(allow []string, clientID string) bool { |
| 287 | + for _, v := range allow { |
| 288 | + if v == clientID { |
| 289 | + return true |
| 290 | + } |
| 291 | + } |
| 292 | + return false |
| 293 | +} |
| 294 | + |
| 295 | +// parseScopes accepts space- or comma-separated scopes and returns the |
| 296 | +// normalised, deduped, validated slice. Unknown scopes → ErrInvalidScope. |
| 297 | +func parseScopes(input string, defaults []string) ([]string, error) { |
| 298 | + input = strings.TrimSpace(input) |
| 299 | + if input == "" { |
| 300 | + out := make([]string, len(defaults)) |
| 301 | + copy(out, defaults) |
| 302 | + return out, nil |
| 303 | + } |
| 304 | + raw := strings.FieldsFunc(input, func(r rune) bool { return r == ',' || r == ' ' }) |
| 305 | + seen := make(map[string]struct{}, len(raw)) |
| 306 | + out := make([]string, 0, len(raw)) |
| 307 | + for _, s := range raw { |
| 308 | + s = strings.TrimSpace(s) |
| 309 | + if s == "" { |
| 310 | + continue |
| 311 | + } |
| 312 | + if !pat.ValidScope(s) { |
| 313 | + return nil, ErrInvalidScope |
| 314 | + } |
| 315 | + if _, dup := seen[s]; dup { |
| 316 | + continue |
| 317 | + } |
| 318 | + seen[s] = struct{}{} |
| 319 | + out = append(out, s) |
| 320 | + } |
| 321 | + if len(out) == 0 { |
| 322 | + out = append(out, defaults...) |
| 323 | + } |
| 324 | + return out, nil |
| 325 | +} |
| 326 | + |
| 327 | +func newDeviceCode() (raw string, hash []byte, err error) { |
| 328 | + buf := make([]byte, 32) |
| 329 | + if _, err := rand.Read(buf); err != nil { |
| 330 | + return "", nil, fmt.Errorf("devicecode: rand: %w", err) |
| 331 | + } |
| 332 | + raw = hexEncode(buf) |
| 333 | + sum := sha256.Sum256([]byte(raw)) |
| 334 | + return raw, sum[:], nil |
| 335 | +} |
| 336 | + |
| 337 | +func hashDeviceCode(raw string) ([]byte, error) { |
| 338 | + raw = strings.TrimSpace(raw) |
| 339 | + if raw == "" { |
| 340 | + return nil, ErrInvalidGrant |
| 341 | + } |
| 342 | + sum := sha256.Sum256([]byte(raw)) |
| 343 | + return sum[:], nil |
| 344 | +} |
| 345 | + |
| 346 | +// newUserCode produces an 8-character "ABCD-EFGH" identifier from a |
| 347 | +// 32-symbol alphabet that excludes 0/O/1/I to avoid typing ambiguity. |
| 348 | +func newUserCode() (string, error) { |
| 349 | + const alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" |
| 350 | + buf := make([]byte, 8) |
| 351 | + if _, err := rand.Read(buf); err != nil { |
| 352 | + return "", err |
| 353 | + } |
| 354 | + out := make([]byte, 0, 9) |
| 355 | + for i, b := range buf { |
| 356 | + out = append(out, alphabet[int(b)%len(alphabet)]) |
| 357 | + if i == 3 { |
| 358 | + out = append(out, '-') |
| 359 | + } |
| 360 | + } |
| 361 | + return string(out), nil |
| 362 | +} |
| 363 | + |
| 364 | +func normaliseUserCode(s string) string { |
| 365 | + s = strings.ToUpper(strings.TrimSpace(s)) |
| 366 | + s = strings.ReplaceAll(s, " ", "") |
| 367 | + if len(s) == 8 && !strings.Contains(s, "-") { |
| 368 | + s = s[:4] + "-" + s[4:] |
| 369 | + } |
| 370 | + return s |
| 371 | +} |
| 372 | + |
| 373 | +func hexEncode(b []byte) string { |
| 374 | + const hexd = "0123456789abcdef" |
| 375 | + out := make([]byte, len(b)*2) |
| 376 | + for i, v := range b { |
| 377 | + out[i*2] = hexd[v>>4] |
| 378 | + out[i*2+1] = hexd[v&0x0f] |
| 379 | + } |
| 380 | + return string(out) |
| 381 | +} |
| 382 | + |
| 383 | +// loadByID resolves a row by its bigserial id. sqlc didn't generate a |
| 384 | +// dedicated GetByID because every other consumer comes in via |
| 385 | +// device_code_hash or user_code; the Approve / Deny path needs id |
| 386 | +// because it has it from the prior LookupByUserCode call. |
| 387 | +func loadByID(ctx context.Context, pool *pgxpool.Pool, id int64) (usersdb.DeviceAuthorization, error) { |
| 388 | + rows, err := pool.Query(ctx, ` |
| 389 | + SELECT id, device_code_hash, user_code, client_id, scopes, user_id, |
| 390 | + approved_at, denied_at, issued_token_id, interval_seconds, |
| 391 | + expires_at, last_polled_at, created_at |
| 392 | + FROM device_authorizations |
| 393 | + WHERE id = $1 |
| 394 | + `, id) |
| 395 | + if err != nil { |
| 396 | + return usersdb.DeviceAuthorization{}, err |
| 397 | + } |
| 398 | + defer rows.Close() |
| 399 | + if !rows.Next() { |
| 400 | + return usersdb.DeviceAuthorization{}, ErrInvalidGrant |
| 401 | + } |
| 402 | + var a usersdb.DeviceAuthorization |
| 403 | + if err := rows.Scan( |
| 404 | + &a.ID, &a.DeviceCodeHash, &a.UserCode, &a.ClientID, &a.Scopes, &a.UserID, |
| 405 | + &a.ApprovedAt, &a.DeniedAt, &a.IssuedTokenID, &a.IntervalSeconds, |
| 406 | + &a.ExpiresAt, &a.LastPolledAt, &a.CreatedAt, |
| 407 | + ); err != nil { |
| 408 | + return usersdb.DeviceAuthorization{}, err |
| 409 | + } |
| 410 | + return a, nil |
| 411 | +} |