Go · 3599 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package throttle implements counter-based rate limiting for the auth
4 // surface, backed by the auth_throttle table. The model is fixed-window:
5 // each (scope, identifier) pair has a hit counter that resets when the
6 // window has elapsed.
7 //
8 // We deliberately use Postgres rather than introducing Redis. At launch
9 // scale this is well within Postgres's comfort zone, and avoiding a new
10 // dependency is worth the marginal latency. Migrate if S36 proves it
11 // necessary.
12 package throttle
13
14 import (
15 "context"
16 "errors"
17 "fmt"
18 "time"
19
20 "github.com/jackc/pgx/v5/pgtype"
21
22 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
23 )
24
25 // DBTX matches sqlc's DBTX so callers can pass a *pgxpool.Pool or a
26 // transaction interchangeably.
27 type DBTX = usersdb.DBTX
28
29 // Limiter holds the queries handle. Construct with NewLimiter.
30 type Limiter struct {
31 q *usersdb.Queries
32 }
33
34 // NewLimiter returns a limiter bound to the sqlc queries package.
35 func NewLimiter() *Limiter {
36 return &Limiter{q: usersdb.New()}
37 }
38
// Limit is one rate-limiting policy. It allows at most Max hits per
// (Scope, Identifier) key within each fixed window of length Window.
// Limiter.Hit permits the first Max hits of a window and rejects the
// rest with *ErrThrottled until the window has elapsed and restarts.
type Limit struct {
	Scope      string        // policy namespace, e.g. "login", "signup", "reset"
	Identifier string        // key within the scope, e.g. "ip:1.2.3.4|alice"
	Max        int           // hits permitted within one window
	Window     time.Duration // fixed-window length; the counter restarts once it elapses
}
46
// ErrThrottled reports that a Limiter.Hit call pushed a (scope, identifier)
// counter past its policy's Max. Hit returns it as a pointer, so detect it
// with errors.As or IsThrottled rather than by comparison.
type ErrThrottled struct {
	// RetryAfter is how long the caller should wait before trying again;
	// it is meant to feed the HTTP `Retry-After` header.
	RetryAfter time.Duration
	// Hits is the counter value observed when the limit was exceeded.
	Hits int
}

// Error implements the error interface, reporting the observed hit count
// and the suggested retry delay.
func (t *ErrThrottled) Error() string {
	const format = "throttle: limit exceeded (hits=%d, retry in %s)"
	return fmt.Sprintf(format, t.Hits, t.RetryAfter)
}

// IsThrottled reports whether err, or any error in its wrap chain, is an
// *ErrThrottled. It is the canonical way to tell a throttle rejection
// apart from any other failure (for example a database error from Hit).
func IsThrottled(err error) bool {
	target := new(*ErrThrottled)
	return errors.As(err, target)
}
64
65 // Hit increments the counter for the policy. Returns nil if under the
66 // limit, ErrThrottled (with RetryAfter) if at or over.
67 //
68 // The query is conditional: if the existing window started before
69 // (now - policy.Window) we treat it as a new window and reset hits to 1.
70 // Otherwise we increment in place. The work happens atomically inside
71 // Postgres so concurrent requests don't undercount.
72 func (l *Limiter) Hit(ctx context.Context, db DBTX, p Limit) error {
73 cutoff := pgtype.Timestamptz{Time: time.Now().Add(-p.Window), Valid: true}
74 row, err := l.q.BumpAuthThrottle(ctx, db, usersdb.BumpAuthThrottleParams{
75 Scope: p.Scope,
76 Identifier: p.Identifier,
77 WindowStartedAt: cutoff,
78 })
79 if err != nil {
80 return fmt.Errorf("throttle: bump: %w", err)
81 }
82 if int(row.Hits) > p.Max {
83 retry := p.Window - time.Since(row.WindowStartedAt.Time)
84 if retry < time.Second {
85 retry = time.Second
86 }
87 return &ErrThrottled{RetryAfter: retry, Hits: int(row.Hits)}
88 }
89 return nil
90 }
91
92 // Reset clears the counter for (scope, identifier). Used after a
93 // successful login to forgive prior failed attempts on the same key.
94 func (l *Limiter) Reset(ctx context.Context, db DBTX, scope, identifier string) error {
95 return l.q.ResetAuthThrottle(ctx, db, usersdb.ResetAuthThrottleParams{
96 Scope: scope,
97 Identifier: identifier,
98 })
99 }
100
101 // Purge deletes throttle rows older than cutoff. Caller is expected to
102 // run this on a periodic schedule (S14 worker).
103 func (l *Limiter) Purge(ctx context.Context, db DBTX, olderThan time.Duration) error {
104 cutoff := pgtype.Timestamptz{Time: time.Now().Add(-olderThan), Valid: true}
105 return l.q.PurgeStaleAuthThrottle(ctx, db, cutoff)
106 }
107