tenseleyflow/shithub / 5167aef

Browse files

Add Postgres-backed auth throttle with fixed-window counter + tests

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
5167aef5d8276679fbf2af6da048e4308e1d7a33
Parents
21bd795
Tree
6bfa897

2 changed files

StatusFile+-
A internal/auth/throttle/throttle.go 106 0
A internal/auth/throttle/throttle_test.go 74 0
internal/auth/throttle/throttle.goadded
@@ -0,0 +1,106 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package throttle implements counter-based rate limiting for the auth
4
+// surface, backed by the auth_throttle table. The model is fixed-window:
5
+// each (scope, identifier) pair has a hit counter that resets when the
6
+// window has elapsed.
7
+//
8
+// We deliberately use Postgres rather than introducing Redis. At launch
9
+// scale this is well within Postgres's comfort zone, and avoiding a new
10
+// dependency is worth the marginal latency. Migrate if S36 proves it
11
+// necessary.
12
+package throttle
13
+
14
+import (
15
+	"context"
16
+	"errors"
17
+	"fmt"
18
+	"time"
19
+
20
+	"github.com/jackc/pgx/v5/pgtype"
21
+
22
+	usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
23
+)
24
+
25
+// DBTX matches sqlc's DBTX so callers can pass a *pgxpool.Pool or a
26
+// transaction interchangeably.
27
+type DBTX = usersdb.DBTX
28
+
29
+// Limiter holds the queries handle. Construct with NewLimiter.
30
+type Limiter struct {
31
+	q *usersdb.Queries
32
+}
33
+
34
+// NewLimiter returns a limiter bound to the sqlc queries package.
35
+func NewLimiter() *Limiter {
36
+	return &Limiter{q: usersdb.New()}
37
+}
38
+
39
+// Limit is a (scope, identifier, max-hits, window) policy.
40
+type Limit struct {
41
+	Scope      string        // e.g. "login", "signup", "reset"
42
+	Identifier string        // e.g. "ip:1.2.3.4|alice"
43
+	Max        int           // hits permitted within Window
44
+	Window     time.Duration // sliding-fixed window length
45
+}
46
+
47
+// ErrThrottled is returned by Hit when the policy's limit is exceeded.
48
+// RetryAfter carries a hint suitable for the HTTP `Retry-After` header.
49
+type ErrThrottled struct {
50
+	RetryAfter time.Duration
51
+	Hits       int
52
+}
53
+
54
+func (e *ErrThrottled) Error() string {
55
+	return fmt.Sprintf("throttle: limit exceeded (hits=%d, retry in %s)", e.Hits, e.RetryAfter)
56
+}
57
+
58
+// IsThrottled is the canonical predicate for distinguishing a throttle
59
+// rejection from a generic error.
60
+func IsThrottled(err error) bool {
61
+	var t *ErrThrottled
62
+	return errors.As(err, &t)
63
+}
64
+
65
+// Hit increments the counter for the policy. Returns nil if under the
66
+// limit, ErrThrottled (with RetryAfter) if at or over.
67
+//
68
+// The query is conditional: if the existing window started before
69
+// (now - policy.Window) we treat it as a new window and reset hits to 1.
70
+// Otherwise we increment in place. The work happens atomically inside
71
+// Postgres so concurrent requests don't undercount.
72
+func (l *Limiter) Hit(ctx context.Context, db DBTX, p Limit) error {
73
+	cutoff := pgtype.Timestamptz{Time: time.Now().Add(-p.Window), Valid: true}
74
+	row, err := l.q.BumpAuthThrottle(ctx, db, usersdb.BumpAuthThrottleParams{
75
+		Scope:           p.Scope,
76
+		Identifier:      p.Identifier,
77
+		WindowStartedAt: cutoff,
78
+	})
79
+	if err != nil {
80
+		return fmt.Errorf("throttle: bump: %w", err)
81
+	}
82
+	if int(row.Hits) > p.Max {
83
+		retry := p.Window - time.Since(row.WindowStartedAt.Time)
84
+		if retry < time.Second {
85
+			retry = time.Second
86
+		}
87
+		return &ErrThrottled{RetryAfter: retry, Hits: int(row.Hits)}
88
+	}
89
+	return nil
90
+}
91
+
92
+// Reset clears the counter for (scope, identifier). Used after a
93
+// successful login to forgive prior failed attempts on the same key.
94
+func (l *Limiter) Reset(ctx context.Context, db DBTX, scope, identifier string) error {
95
+	return l.q.ResetAuthThrottle(ctx, db, usersdb.ResetAuthThrottleParams{
96
+		Scope:      scope,
97
+		Identifier: identifier,
98
+	})
99
+}
100
+
101
+// Purge deletes throttle rows older than cutoff. Caller is expected to
102
+// run this on a periodic schedule (S14 worker).
103
+func (l *Limiter) Purge(ctx context.Context, db DBTX, olderThan time.Duration) error {
104
+	cutoff := pgtype.Timestamptz{Time: time.Now().Add(-olderThan), Valid: true}
105
+	return l.q.PurgeStaleAuthThrottle(ctx, db, cutoff)
106
+}
internal/auth/throttle/throttle_test.goadded
@@ -0,0 +1,74 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package throttle
4
+
5
+import (
6
+	"context"
7
+	"testing"
8
+	"time"
9
+
10
+	"github.com/tenseleyFlow/shithub/internal/testing/dbtest"
11
+)
12
+
13
+func TestLimiter_HitAndThrottle(t *testing.T) {
14
+	t.Parallel()
15
+	pool := dbtest.NewTestDB(t)
16
+	ctx := context.Background()
17
+	l := NewLimiter()
18
+
19
+	p := Limit{Scope: "login", Identifier: "ip:1.2.3.4|alice", Max: 3, Window: time.Hour}
20
+
21
+	for i := 1; i <= 3; i++ {
22
+		if err := l.Hit(ctx, pool, p); err != nil {
23
+			t.Fatalf("hit %d: %v", i, err)
24
+		}
25
+	}
26
+	err := l.Hit(ctx, pool, p)
27
+	if !IsThrottled(err) {
28
+		t.Fatalf("4th hit: expected throttled, got %v", err)
29
+	}
30
+}
31
+
32
+func TestLimiter_Reset(t *testing.T) {
33
+	t.Parallel()
34
+	pool := dbtest.NewTestDB(t)
35
+	ctx := context.Background()
36
+	l := NewLimiter()
37
+
38
+	p := Limit{Scope: "login", Identifier: "ip:1.2.3.4|bob", Max: 1, Window: time.Hour}
39
+
40
+	if err := l.Hit(ctx, pool, p); err != nil {
41
+		t.Fatalf("first hit: %v", err)
42
+	}
43
+	if err := l.Hit(ctx, pool, p); !IsThrottled(err) {
44
+		t.Fatalf("second hit before reset: expected throttled, got %v", err)
45
+	}
46
+	if err := l.Reset(ctx, pool, p.Scope, p.Identifier); err != nil {
47
+		t.Fatalf("reset: %v", err)
48
+	}
49
+	if err := l.Hit(ctx, pool, p); err != nil {
50
+		t.Fatalf("hit after reset: %v", err)
51
+	}
52
+}
53
+
54
+func TestLimiter_WindowReset(t *testing.T) {
55
+	t.Parallel()
56
+	pool := dbtest.NewTestDB(t)
57
+	ctx := context.Background()
58
+	l := NewLimiter()
59
+
60
+	// Window is short enough that the second hit lands in a brand-new
61
+	// window. The bump query resets the counter when the existing window
62
+	// started before (now - Window). Use a generous sleep so clock
63
+	// granularity / connection latency between the Go cutoff and the PG
64
+	// now() can't make the comparison ambiguous.
65
+	p := Limit{Scope: "login", Identifier: "ip:1.2.3.4|carol", Max: 1, Window: 200 * time.Millisecond}
66
+
67
+	if err := l.Hit(ctx, pool, p); err != nil {
68
+		t.Fatalf("first hit: %v", err)
69
+	}
70
+	time.Sleep(500 * time.Millisecond)
71
+	if err := l.Hit(ctx, pool, p); err != nil {
72
+		t.Fatalf("hit after window: expected fresh window, got %v", err)
73
+	}
74
+}