| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | // Package throttle implements counter-based rate limiting for the auth |
| 4 | // surface, backed by the auth_throttle table. The model is fixed-window: |
| 5 | // each (scope, identifier) pair has a hit counter that resets when the |
| 6 | // window has elapsed. |
| 7 | // |
| 8 | // We deliberately use Postgres rather than introducing Redis. At launch |
| 9 | // scale this is well within Postgres's comfort zone, and avoiding a new |
| 10 | // dependency is worth the marginal latency. Migrate if S36 proves it |
| 11 | // necessary. |
| 12 | package throttle |
| 13 | |
| 14 | import ( |
| 15 | "context" |
| 16 | "errors" |
| 17 | "fmt" |
| 18 | "time" |
| 19 | |
| 20 | "github.com/jackc/pgx/v5/pgtype" |
| 21 | |
| 22 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 23 | ) |
| 24 | |
// DBTX matches sqlc's DBTX so callers can pass a *pgxpool.Pool or a
// transaction interchangeably. It is a type alias rather than a new named
// type, so any value usable as usersdb.DBTX is accepted here without
// conversion.
type DBTX = usersdb.DBTX
| 28 | |
| 29 | // Limiter holds the queries handle. Construct with NewLimiter. |
| 30 | type Limiter struct { |
| 31 | q *usersdb.Queries |
| 32 | } |
| 33 | |
| 34 | // NewLimiter returns a limiter bound to the sqlc queries package. |
| 35 | func NewLimiter() *Limiter { |
| 36 | return &Limiter{q: usersdb.New()} |
| 37 | } |
| 38 | |
// Limit is a fixed-window rate-limit policy: at most Max hits per
// (Scope, Identifier) pair within each Window.
type Limit struct {
	// Scope namespaces the counter by action, e.g. "login", "signup", "reset".
	Scope string

	// Identifier keys the counter within Scope, e.g. "ip:1.2.3.4|alice".
	Identifier string

	// Max is the number of hits permitted per window. The (Max+1)th and
	// later hits in the same window are rejected by Hit.
	Max int

	// Window is the fixed window length. The window does not slide: once
	// it has fully elapsed, the next hit starts a new window.
	Window time.Duration
}
| 46 | |
// ErrThrottled is the error Hit returns once a policy's limit has been
// exceeded. RetryAfter carries a hint suitable for the HTTP `Retry-After`
// header; Hits is the counter value reported with the rejection.
type ErrThrottled struct {
	RetryAfter time.Duration
	Hits       int
}

// Error renders the rejection with its hit count and retry hint.
func (te *ErrThrottled) Error() string {
	wait := te.RetryAfter.String()
	return fmt.Sprintf("throttle: limit exceeded (hits=%d, retry in %s)", te.Hits, wait)
}

// IsThrottled reports whether err, or any error in its wrap chain, is an
// *ErrThrottled. It is the canonical way to tell a throttle rejection
// apart from any other failure.
func IsThrottled(err error) bool {
	return errors.As(err, new(*ErrThrottled))
}
| 64 | |
| 65 | // Hit increments the counter for the policy. Returns nil if under the |
| 66 | // limit, ErrThrottled (with RetryAfter) if at or over. |
| 67 | // |
| 68 | // The query is conditional: if the existing window started before |
| 69 | // (now - policy.Window) we treat it as a new window and reset hits to 1. |
| 70 | // Otherwise we increment in place. The work happens atomically inside |
| 71 | // Postgres so concurrent requests don't undercount. |
| 72 | func (l *Limiter) Hit(ctx context.Context, db DBTX, p Limit) error { |
| 73 | cutoff := pgtype.Timestamptz{Time: time.Now().Add(-p.Window), Valid: true} |
| 74 | row, err := l.q.BumpAuthThrottle(ctx, db, usersdb.BumpAuthThrottleParams{ |
| 75 | Scope: p.Scope, |
| 76 | Identifier: p.Identifier, |
| 77 | WindowStartedAt: cutoff, |
| 78 | }) |
| 79 | if err != nil { |
| 80 | return fmt.Errorf("throttle: bump: %w", err) |
| 81 | } |
| 82 | if int(row.Hits) > p.Max { |
| 83 | retry := p.Window - time.Since(row.WindowStartedAt.Time) |
| 84 | if retry < time.Second { |
| 85 | retry = time.Second |
| 86 | } |
| 87 | return &ErrThrottled{RetryAfter: retry, Hits: int(row.Hits)} |
| 88 | } |
| 89 | return nil |
| 90 | } |
| 91 | |
| 92 | // Reset clears the counter for (scope, identifier). Used after a |
| 93 | // successful login to forgive prior failed attempts on the same key. |
| 94 | func (l *Limiter) Reset(ctx context.Context, db DBTX, scope, identifier string) error { |
| 95 | return l.q.ResetAuthThrottle(ctx, db, usersdb.ResetAuthThrottleParams{ |
| 96 | Scope: scope, |
| 97 | Identifier: identifier, |
| 98 | }) |
| 99 | } |
| 100 | |
| 101 | // Purge deletes throttle rows older than cutoff. Caller is expected to |
| 102 | // run this on a periodic schedule (S14 worker). |
| 103 | func (l *Limiter) Purge(ctx context.Context, db DBTX, olderThan time.Duration) error { |
| 104 | cutoff := pgtype.Timestamptz{Time: time.Now().Add(-olderThan), Valid: true} |
| 105 | return l.q.PurgeStaleAuthThrottle(ctx, db, cutoff) |
| 106 | } |
| 107 |