@@ -0,0 +1,106 @@ |
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | + |
| 3 | +// Package throttle implements counter-based rate limiting for the auth |
| 4 | +// surface, backed by the auth_throttle table. The model is fixed-window: |
| 5 | +// each (scope, identifier) pair has a hit counter that resets when the |
| 6 | +// window has elapsed. |
| 7 | +// |
| 8 | +// We deliberately use Postgres rather than introducing Redis. At launch |
| 9 | +// scale this is well within Postgres's comfort zone, and avoiding a new |
| 10 | +// dependency is worth the marginal latency. Migrate if S36 proves it |
| 11 | +// necessary. |
| 12 | +package throttle |
| 13 | + |
| 14 | +import ( |
| 15 | + "context" |
| 16 | + "errors" |
| 17 | + "fmt" |
| 18 | + "time" |
| 19 | + |
| 20 | + "github.com/jackc/pgx/v5/pgtype" |
| 21 | + |
| 22 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 23 | +) |
| 24 | + |
| 25 | +// DBTX matches sqlc's DBTX so callers can pass a *pgxpool.Pool or a |
| 26 | +// transaction interchangeably. |
| 27 | +type DBTX = usersdb.DBTX |
| 28 | + |
| 29 | +// Limiter holds the queries handle. Construct with NewLimiter. |
| 30 | +type Limiter struct { |
| 31 | + q *usersdb.Queries |
| 32 | +} |
| 33 | + |
| 34 | +// NewLimiter returns a limiter bound to the sqlc queries package. |
| 35 | +func NewLimiter() *Limiter { |
| 36 | + return &Limiter{q: usersdb.New()} |
| 37 | +} |
| 38 | + |
| 39 | +// Limit is a (scope, identifier, max-hits, window) policy. |
| 40 | +type Limit struct { |
| 41 | + Scope string // e.g. "login", "signup", "reset" |
| 42 | + Identifier string // e.g. "ip:1.2.3.4|alice" |
| 43 | + Max int // hits permitted within Window |
| 44 | + Window time.Duration // sliding-fixed window length |
| 45 | +} |
| 46 | + |
| 47 | +// ErrThrottled is returned by Hit when the policy's limit is exceeded. |
| 48 | +// RetryAfter carries a hint suitable for the HTTP `Retry-After` header. |
| 49 | +type ErrThrottled struct { |
| 50 | + RetryAfter time.Duration |
| 51 | + Hits int |
| 52 | +} |
| 53 | + |
| 54 | +func (e *ErrThrottled) Error() string { |
| 55 | + return fmt.Sprintf("throttle: limit exceeded (hits=%d, retry in %s)", e.Hits, e.RetryAfter) |
| 56 | +} |
| 57 | + |
| 58 | +// IsThrottled is the canonical predicate for distinguishing a throttle |
| 59 | +// rejection from a generic error. |
| 60 | +func IsThrottled(err error) bool { |
| 61 | + var t *ErrThrottled |
| 62 | + return errors.As(err, &t) |
| 63 | +} |
| 64 | + |
| 65 | +// Hit increments the counter for the policy. Returns nil if under the |
| 66 | +// limit, ErrThrottled (with RetryAfter) if at or over. |
| 67 | +// |
| 68 | +// The query is conditional: if the existing window started before |
| 69 | +// (now - policy.Window) we treat it as a new window and reset hits to 1. |
| 70 | +// Otherwise we increment in place. The work happens atomically inside |
| 71 | +// Postgres so concurrent requests don't undercount. |
| 72 | +func (l *Limiter) Hit(ctx context.Context, db DBTX, p Limit) error { |
| 73 | + cutoff := pgtype.Timestamptz{Time: time.Now().Add(-p.Window), Valid: true} |
| 74 | + row, err := l.q.BumpAuthThrottle(ctx, db, usersdb.BumpAuthThrottleParams{ |
| 75 | + Scope: p.Scope, |
| 76 | + Identifier: p.Identifier, |
| 77 | + WindowStartedAt: cutoff, |
| 78 | + }) |
| 79 | + if err != nil { |
| 80 | + return fmt.Errorf("throttle: bump: %w", err) |
| 81 | + } |
| 82 | + if int(row.Hits) > p.Max { |
| 83 | + retry := p.Window - time.Since(row.WindowStartedAt.Time) |
| 84 | + if retry < time.Second { |
| 85 | + retry = time.Second |
| 86 | + } |
| 87 | + return &ErrThrottled{RetryAfter: retry, Hits: int(row.Hits)} |
| 88 | + } |
| 89 | + return nil |
| 90 | +} |
| 91 | + |
| 92 | +// Reset clears the counter for (scope, identifier). Used after a |
| 93 | +// successful login to forgive prior failed attempts on the same key. |
| 94 | +func (l *Limiter) Reset(ctx context.Context, db DBTX, scope, identifier string) error { |
| 95 | + return l.q.ResetAuthThrottle(ctx, db, usersdb.ResetAuthThrottleParams{ |
| 96 | + Scope: scope, |
| 97 | + Identifier: identifier, |
| 98 | + }) |
| 99 | +} |
| 100 | + |
| 101 | +// Purge deletes throttle rows older than cutoff. Caller is expected to |
| 102 | +// run this on a periodic schedule (S14 worker). |
| 103 | +func (l *Limiter) Purge(ctx context.Context, db DBTX, olderThan time.Duration) error { |
| 104 | + cutoff := pgtype.Timestamptz{Time: time.Now().Add(-olderThan), Valid: true} |
| 105 | + return l.q.PurgeStaleAuthThrottle(ctx, db, cutoff) |
| 106 | +} |