// SPDX-License-Identifier: AGPL-3.0-or-later

package worker_test

import (
	"context"
	"encoding/json"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jackc/pgx/v5/pgtype"

	"github.com/tenseleyFlow/shithub/internal/testing/dbtest"
	"github.com/tenseleyFlow/shithub/internal/worker"
	workerdb "github.com/tenseleyFlow/shithub/internal/worker/sqlc"
)

// Test kinds are unique per test so handlers don't bleed across parallel
// tests sharing the worker package's runtime state.
const (
	testKindHappy   worker.Kind = "test:happy"
	testKindRetry   worker.Kind = "test:retry"
	testKindPoison  worker.Kind = "test:poison"
	testKindFanIn50 worker.Kind = "test:fanin50"
)

// runPool starts the pool in a goroutine and returns a stop func that
// cancels the context and waits for a clean exit.
func runPool(t *testing.T, p *worker.Pool) (cancel func()) {
	t.Helper()
	ctx, c := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		_ = p.Run(ctx)
		close(done)
	}()
	return func() {
		c()
		select {
		case <-done:
		case <-time.After(10 * time.Second):
			t.Fatal("pool did not stop in 10s")
		}
	}
}

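// TestPool_HappyPath enqueues a single job and checks that the registered
// handler runs exactly once and the row ends up with completed_at set.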
func TestPool_HappyPath(t *testing.T) {
	t.Parallel()
	pool := dbtest.NewTestDB(t)

	var seen atomic.Int64
	p := worker.NewPool(pool, worker.PoolConfig{Workers: 2, IdlePoll: 200 * time.Millisecond})
	p.Register(testKindHappy, func(_ context.Context, _ json.RawMessage) error {
		seen.Add(1)
		return nil
	})
	stop := runPool(t, p)
	defer stop()

	id, err := worker.Enqueue(context.Background(), pool, testKindHappy, map[string]any{"x": 1}, worker.EnqueueOptions{})
	if err != nil {
		t.Fatalf("Enqueue: %v", err)
	}
	if err := worker.Notify(context.Background(), pool); err != nil {
		t.Fatalf("Notify: %v", err)
	}

	waitFor(t, 5*time.Second, func() bool { return seen.Load() == 1 })

	q := workerdb.New()
	job, err := q.GetJob(context.Background(), pool, id)
	if err != nil {
		t.Fatalf("GetJob: %v", err)
	}
	if !job.CompletedAt.Valid {
		t.Errorf("job %d: completed_at unset; last_error=%v", id, job.LastError.String)
	}
}

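// TestPool_RetryThenSucceed fails the first attempt, forces the rescheduled
// run_at back to now, and checks the job completes on the second attempt.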
func TestPool_RetryThenSucceed(t *testing.T) {
	t.Parallel()
	pool := dbtest.NewTestDB(t)

	var attempts atomic.Int64
	p := worker.NewPool(pool, worker.PoolConfig{Workers: 2, IdlePoll: 100 * time.Millisecond})
	p.Register(testKindRetry, func(_ context.Context, _ json.RawMessage) error {
		if attempts.Add(1) < 2 {
			return errors.New("transient")
		}
		return nil
	})
	stop := runPool(t, p)
	defer stop()

	// Allow several attempts; the retry backoff is bypassed below so the
	// rescheduled run fires immediately.
	id, err := worker.Enqueue(context.Background(), pool, testKindRetry, map[string]any{}, worker.EnqueueOptions{MaxAttempts: 5})
	if err != nil {
		t.Fatalf("Enqueue: %v", err)
	}
	_ = worker.Notify(context.Background(), pool)

	// Don't wait out the full backoff: after the first attempt fails, the
	// row's run_at is rescheduled base * 2 ≈ 60s out. Once that attempt has
	// happened, force run_at back to now with a direct UPDATE and re-notify.
	q := workerdb.New()
	waitFor(t, 5*time.Second, func() bool { return attempts.Load() >= 1 })
	if _, err := pool.Exec(context.Background(), `UPDATE jobs SET run_at = now() WHERE id = $1`, id); err != nil {
		t.Fatalf("force run_at: %v", err)
	}
	_ = worker.Notify(context.Background(), pool)

	waitFor(t, 5*time.Second, func() bool { return attempts.Load() >= 2 })
	waitFor(t, 5*time.Second, func() bool {
		j, err := q.GetJob(context.Background(), pool, id)
		return err == nil && j.CompletedAt.Valid
	})
}

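// TestPool_PoisonGoesStraightToFailed checks that a PoisonError marks the
// job failed after a single attempt, with no retry and last_error recorded.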
func TestPool_PoisonGoesStraightToFailed(t *testing.T) {
	t.Parallel()
	pool := dbtest.NewTestDB(t)

	var calls atomic.Int64
	p := worker.NewPool(pool, worker.PoolConfig{Workers: 1, IdlePoll: 100 * time.Millisecond})
	p.Register(testKindPoison, func(_ context.Context, _ json.RawMessage) error {
		calls.Add(1)
		return worker.PoisonError(errors.New("nope"))
	})
	stop := runPool(t, p)
	defer stop()

	id, err := worker.Enqueue(context.Background(), pool, testKindPoison, map[string]any{}, worker.EnqueueOptions{})
	if err != nil {
		t.Fatalf("Enqueue: %v", err)
	}
	_ = worker.Notify(context.Background(), pool)

	q := workerdb.New()
	waitFor(t, 5*time.Second, func() bool {
		j, err := q.GetJob(context.Background(), pool, id)
		return err == nil && j.FailedAt.Valid
	})
	if got := calls.Load(); got != 1 {
		t.Errorf("calls = %d, want 1 (no retry on poison)", got)
	}
	j, _ := q.GetJob(context.Background(), pool, id)
	if !j.LastError.Valid || j.LastError.String == "" {
		t.Errorf("last_error not recorded on poison")
	}
}

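// TestPool_ConcurrentClaimsExactlyOnce fans 50 jobs out across 4 workers and
// checks that every payload is processed exactly once despite concurrent claims.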
func TestPool_ConcurrentClaimsExactlyOnce(t *testing.T) {
	t.Parallel()
	pool := dbtest.NewTestDB(t)

	const total = 50
	processed := make(map[int64]int) // payload id → times processed
	var mu sync.Mutex
	p := worker.NewPool(pool, worker.PoolConfig{Workers: 4, IdlePoll: 50 * time.Millisecond})
	p.Register(testKindFanIn50, func(_ context.Context, raw json.RawMessage) error {
		var payload struct {
			ID int64 `json:"id"`
		}
		_ = json.Unmarshal(raw, &payload)
		mu.Lock()
		processed[payload.ID]++
		mu.Unlock()
		return nil
	})
	stop := runPool(t, p)
	defer stop()

	for i := 0; i < total; i++ {
		_, err := worker.Enqueue(context.Background(), pool, testKindFanIn50,
			map[string]any{"id": i}, worker.EnqueueOptions{})
		if err != nil {
			t.Fatalf("Enqueue: %v", err)
		}
	}
	_ = worker.Notify(context.Background(), pool)

	waitFor(t, 10*time.Second, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return len(processed) == total
	})

	mu.Lock()
	defer mu.Unlock()
	for id, count := range processed {
		if count != 1 {
			t.Errorf("payload %d processed %d times, want 1", id, count)
		}
	}
}

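// TestEnqueue_DelayedRunAt checks that an explicit future RunAt is persisted
// on the enqueued row.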
func TestEnqueue_DelayedRunAt(t *testing.T) {
	t.Parallel()
	pool := dbtest.NewTestDB(t)
	future := time.Now().Add(1 * time.Hour)
	id, err := worker.Enqueue(context.Background(), pool, "test:delayed", map[string]any{}, worker.EnqueueOptions{
		RunAt: pgtype.Timestamptz{Time: future, Valid: true},
	})
	if err != nil {
		t.Fatalf("Enqueue: %v", err)
	}
	q := workerdb.New()
	job, err := q.GetJob(context.Background(), pool, id)
	if err != nil {
		t.Fatalf("GetJob: %v", err)
	}
	if !job.RunAt.Time.Equal(future.UTC().Truncate(time.Microsecond)) {
		// pg truncates to microseconds; allow a tiny delta.
		if d := job.RunAt.Time.Sub(future); d > time.Second || d < -time.Second {
			t.Errorf("run_at = %v, want %v", job.RunAt.Time, future)
		}
	}
}

// waitFor polls cond every 50ms up to limit. Fails the test on timeout.
func waitFor(t *testing.T, limit time.Duration, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(limit)
	for time.Now().Before(deadline) {
		if cond() {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Fatalf("waitFor: condition not met within %v", limit)
}