// SPDX-License-Identifier: AGPL-3.0-or-later

package worker

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	mrand "math/rand"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/tenseleyFlow/shithub/internal/infra/metrics"
	workerdb "github.com/tenseleyFlow/shithub/internal/worker/sqlc"
)

// PoolConfig configures Pool. Leave fields zero for the documented
// defaults.
type PoolConfig struct {
	Workers    int           // default 4
	IdlePoll   time.Duration // default 5s — backstop when LISTEN drops a wake
	JobTimeout time.Duration // default 5min, applied per-job via context
	InstanceID string        // default "<hostname>:<pid>"
	Logger     *slog.Logger  // default logs to stderr at Info
}

// Pool dispatches jobs from the queue. Construct via NewPool, register
// handlers via Register, run via Run.
type Pool struct {
	cfg      PoolConfig
	db       *pgxpool.Pool
	q        *workerdb.Queries
	handlers map[Kind]Handler
	rng      *mrand.Rand
	mu       sync.Mutex // guards handlers + rng
}

// NewPool wires a pool against an open pgx pool. Callers register
// handlers before calling Run.
func NewPool(db *pgxpool.Pool, cfg PoolConfig) *Pool {
	if cfg.Workers <= 0 {
		cfg.Workers = 4
	}
	if cfg.IdlePoll <= 0 {
		cfg.IdlePoll = 5 * time.Second
	}
	if cfg.JobTimeout <= 0 {
		cfg.JobTimeout = 5 * time.Minute
	}
	if cfg.InstanceID == "" {
		host, _ := os.Hostname()
		cfg.InstanceID = host + ":" + strconv.Itoa(os.Getpid())
	}
	if cfg.Logger == nil {
		cfg.Logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
	}
	return &Pool{
		cfg:      cfg,
		db:       db,
		q:        workerdb.New(),
		handlers: make(map[Kind]Handler),
		//nolint:gosec // G404: jitter is non-cryptographic by design.
		rng: mrand.New(mrand.NewSource(time.Now().UnixNano())),
	}
}
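
// Usage sketch (illustrative, not part of this package's API surface: the
// "repo.gc" kind and the inline handler are hypothetical names). A caller
// opens a pgx pool, registers handlers, then blocks in Run until ctx is
// cancelled.
func examplePoolUsage(ctx context.Context, db *pgxpool.Pool) error {
	p := NewPool(db, PoolConfig{
		Workers:    8,
		JobTimeout: 2 * time.Minute,
	})
	p.Register(Kind("repo.gc"), func(ctx context.Context, payload json.RawMessage) error {
		// Decode payload and do the work; a nil return marks the job completed.
		return nil
	})
	return p.Run(ctx)
}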

// Register associates a handler with a kind. Re-registering replaces
// the previous handler. Registration is goroutine-safe so test harnesses
// can swap handlers between runs.
func (p *Pool) Register(kind Kind, h Handler) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.handlers[kind] = h
}
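
// Handler sketch (illustrative; the payload shape, field name, and kind are
// assumptions, not defined by this package). Handlers receive the raw JSON
// payload under a context already bounded by JobTimeout; wrapping ErrPoison
// marks the job failed immediately instead of retrying it.
func exampleHandler(ctx context.Context, payload json.RawMessage) error {
	var args struct {
		RepoID int64 `json:"repo_id"`
	}
	if err := json.Unmarshal(payload, &args); err != nil {
		// A payload that can never decode is not worth retrying.
		return fmt.Errorf("%w: decode payload: %v", ErrPoison, err)
	}
	// ... do the work for args.RepoID, honouring ctx cancellation ...
	return nil
}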

// Run blocks until ctx is cancelled. It spawns cfg.Workers worker goroutines
// plus one LISTEN goroutine that fans out wake-ups, waits for every goroutine
// to drain, and then returns nil.
func (p *Pool) Run(ctx context.Context) error {
	p.cfg.Logger.InfoContext(ctx, "worker: starting",
		"workers", p.cfg.Workers,
		"instance_id", p.cfg.InstanceID,
		"kinds", p.kindList())

	wake := make(chan struct{}, p.cfg.Workers)
	var wg sync.WaitGroup

	// LISTEN goroutine. Holds a dedicated conn for the lifetime of Run.
	wg.Add(1)
	go func() {
		defer wg.Done()
		p.listenLoop(ctx, wake)
	}()

	for i := 0; i < p.cfg.Workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			p.workerLoop(ctx, id, wake)
		}(i)
	}

	wg.Wait()
	p.cfg.Logger.InfoContext(ctx, "worker: stopped")
	return nil
}
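
// Shutdown sketch (illustrative; the signal wiring around it is left out).
// Cancelling the context passed to Run stops the LISTEN and worker loops;
// any in-flight job sees the cancellation through its JobTimeout-derived
// context, and Run returns once every goroutine has exited.
func exampleShutdown(db *pgxpool.Pool) error {
	ctx, cancel := context.WithCancel(context.Background())
	p := NewPool(db, PoolConfig{})
	// Register handlers here as in the usage sketch above.

	done := make(chan error, 1)
	go func() { done <- p.Run(ctx) }()

	// ... later, on SIGTERM or an admin request:
	cancel()
	return <-done
}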

func (p *Pool) kindList() []string {
	p.mu.Lock()
	defer p.mu.Unlock()
	out := make([]string, 0, len(p.handlers))
	for k := range p.handlers {
		out = append(out, string(k))
	}
	return out
}

// listenLoop maintains a LISTEN on NotifyChannel. On each NOTIFY it fans
// out to wake; on reconnect-required errors it sleeps briefly and retries.
func (p *Pool) listenLoop(ctx context.Context, wake chan<- struct{}) {
	for {
		if err := ctx.Err(); err != nil {
			return
		}
		if err := p.listenOnce(ctx, wake); err != nil && !errors.Is(err, context.Canceled) {
			p.cfg.Logger.WarnContext(ctx, "worker: listen restart", "error", err)
			select {
			case <-ctx.Done():
				return
			case <-time.After(2 * time.Second):
			}
		}
	}
}

func (p *Pool) listenOnce(ctx context.Context, wake chan<- struct{}) error {
	conn, err := p.db.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquire: %w", err)
	}
	defer conn.Release()

	if _, err := conn.Exec(ctx, "LISTEN "+NotifyChannel); err != nil {
		return fmt.Errorf("LISTEN: %w", err)
	}
	for {
		_, err := conn.Conn().WaitForNotification(ctx)
		if err != nil {
			return err
		}
		// Fan out to as many workers as are idle. Non-blocking sends so
		// we don't stall on a saturated pool.
		for i := 0; i < p.cfg.Workers; i++ {
			select {
			case wake <- struct{}{}:
			default:
			}
		}
	}
}
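
// Enqueue-side sketch (illustrative; the actual enqueue query lives elsewhere
// and is not shown here). After inserting a job row, a pg_notify on
// NotifyChannel wakes idle workers immediately instead of waiting out IdlePoll.
func exampleNotifyAfterEnqueue(ctx context.Context, db *pgxpool.Pool) error {
	// The notification payload is unused by listenOnce, so an empty string is enough.
	_, err := db.Exec(ctx, "SELECT pg_notify($1, '')", NotifyChannel)
	return err
}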

func (p *Pool) workerLoop(ctx context.Context, id int, wake <-chan struct{}) {
	logger := p.cfg.Logger.With("worker_id", id)
	ticker := time.NewTicker(p.cfg.IdlePoll)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-wake:
		case <-ticker.C:
		}
		// Drain: try every registered kind; if any kind returned a job,
		// loop back immediately without waiting on wake/tick.
		for {
			any, err := p.tryClaimAndRun(ctx, logger)
			if err != nil {
				logger.WarnContext(ctx, "worker: claim cycle error", "error", err)
				break
			}
			if !any {
				break
			}
		}
	}
}

// tryClaimAndRun walks every registered kind and attempts to claim one
// job per kind. Returns true if any kind produced work this pass.
func (p *Pool) tryClaimAndRun(ctx context.Context, logger *slog.Logger) (bool, error) {
	p.mu.Lock()
	kinds := make([]Kind, 0, len(p.handlers))
	for k := range p.handlers {
		kinds = append(kinds, k)
	}
	p.mu.Unlock()

	any := false
	for _, kind := range kinds {
		if ctx.Err() != nil {
			return any, ctx.Err()
		}
		ran, err := p.runOne(ctx, kind, logger)
		if err != nil {
			return any, err
		}
		if ran {
			any = true
		}
	}
	return any, nil
}

// runOne claims one job of the given kind, runs it, and records the
// outcome. Returns ran=true when a job was claimed (regardless of
// success), false when the queue had nothing for this kind.
func (p *Pool) runOne(ctx context.Context, kind Kind, logger *slog.Logger) (bool, error) {
	job, err := p.q.ClaimJob(ctx, p.db, workerdb.ClaimJobParams{
		Kind:     string(kind),
		LockedBy: pgtype.Text{String: p.cfg.InstanceID, Valid: true},
	})
	if errors.Is(err, pgx.ErrNoRows) {
		return false, nil
	}
	if err != nil {
		return false, fmt.Errorf("claim %s: %w", kind, err)
	}

	p.mu.Lock()
	h, ok := p.handlers[Kind(job.Kind)]
	p.mu.Unlock()
	if !ok {
		// Registered handler vanished between claim and dispatch; fail
		// the job rather than block the queue. Should never happen in
		// practice — Register is one-shot at boot.
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: "no handler registered", Valid: true},
		})
		return true, nil
	}

	jobCtx, cancel := context.WithTimeout(ctx, p.cfg.JobTimeout)
	start := time.Now()
	metrics.WorkerInFlight.WithLabelValues(job.Kind).Inc()
	runErr := safeRun(jobCtx, h, job.Payload)
	metrics.WorkerInFlight.WithLabelValues(job.Kind).Dec()
	cancel()
	metrics.WorkerJobDurationSeconds.WithLabelValues(job.Kind).Observe(time.Since(start).Seconds())

	logger.InfoContext(
		ctx, "worker: dispatched",
		"job_id", job.ID,
		"kind", job.Kind,
		"attempt", job.Attempts,
		"duration_ms", time.Since(start).Milliseconds(),
		"ok", runErr == nil,
	)

	if runErr == nil {
		if err := p.q.MarkJobCompleted(ctx, p.db, job.ID); err != nil {
			logger.ErrorContext(ctx, "worker: mark completed", "job_id", job.ID, "error", err)
		}
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "ok").Inc()
		return true, nil
	}

	// Failure path. Poison errors skip retry.
	if errors.Is(runErr, ErrPoison) {
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		})
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "poison").Inc()
		return true, nil
	}

	if int(job.Attempts) >= int(job.MaxAttempts) {
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		})
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "failed").Inc()
		return true, nil
	}

	p.mu.Lock()
	delay := Backoff(int(job.Attempts), p.rng.Float64)
	p.mu.Unlock()
	_ = p.q.RescheduleJob(ctx, p.db, workerdb.RescheduleJobParams{
		ID:        job.ID,
		LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		RunAt:     pgtype.Timestamptz{Time: time.Now().Add(delay), Valid: true},
	})
	metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "retry").Inc()
	return true, nil
}
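
// Backoff is defined elsewhere in this package; the sketch below (under a
// hypothetical name so it does not collide) shows one plausible shape, which
// is an assumption rather than the actual implementation: capped exponential
// delay with multiplicative jitter fed by the guarded rng above.
func backoffSketch(attempt int, randFloat func() float64) time.Duration {
	// Double per attempt, capped so late retries don't drift off indefinitely.
	d := time.Second
	for i := 0; i < attempt && d < 10*time.Minute; i++ {
		d *= 2
	}
	// +/-25% jitter so a burst of failures doesn't retry in lockstep.
	return time.Duration(float64(d) * (0.75 + 0.5*randFloat()))
}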

// safeRun wraps the handler in a recover so a panicking handler doesn't
// take the worker goroutine with it; the job is rescheduled like any
// other failure.
func safeRun(ctx context.Context, h Handler, payload json.RawMessage) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("worker: handler panic: %v", r)
		}
	}()
	return h(ctx, payload)
}
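
// Panic-containment sketch (illustrative): a panicking handler surfaces as an
// ordinary error from safeRun, so runOne follows the normal retry path instead
// of losing the worker goroutine.
func examplePanicContained(ctx context.Context) error {
	err := safeRun(ctx, func(context.Context, json.RawMessage) error {
		panic("boom")
	}, nil)
	// err wraps the panic message; it does not match ErrPoison, so the job
	// would be rescheduled with backoff rather than marked failed outright.
	return err
}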
319