// SPDX-License-Identifier: AGPL-3.0-or-later

package worker

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	mrand "math/rand"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/tenseleyFlow/shithub/internal/infra/metrics"
	workerdb "github.com/tenseleyFlow/shithub/internal/worker/sqlc"
)
// PoolConfig configures Pool. Leave fields zero for the documented
// defaults.
type PoolConfig struct {
	Workers    int           // default 4
	IdlePoll   time.Duration // default 5s; backstop when LISTEN drops a wake
	JobTimeout time.Duration // default 5min, applied per-job via context
	InstanceID string        // default "<hostname>:<pid>"
	Logger     *slog.Logger  // default logs to stderr at Info
}

// Pool dispatches jobs from the queue. Construct via NewPool, register
// handlers via Register, run via Run.
type Pool struct {
	cfg      PoolConfig
	db       *pgxpool.Pool
	q        *workerdb.Queries
	handlers map[Kind]Handler
	rng      *mrand.Rand
	mu       sync.Mutex // guards handlers + rng
}

// NewPool wires a pool against an open pgx pool. Callers register
// handlers before calling Run.
func NewPool(db *pgxpool.Pool, cfg PoolConfig) *Pool {
	if cfg.Workers <= 0 {
		cfg.Workers = 4
	}
	if cfg.IdlePoll <= 0 {
		cfg.IdlePoll = 5 * time.Second
	}
	if cfg.JobTimeout <= 0 {
		cfg.JobTimeout = 5 * time.Minute
	}
	if cfg.InstanceID == "" {
		host, _ := os.Hostname()
		cfg.InstanceID = host + ":" + strconv.Itoa(os.Getpid())
	}
	if cfg.Logger == nil {
		cfg.Logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
	}
	return &Pool{
		cfg:      cfg,
		db:       db,
		q:        workerdb.New(),
		handlers: make(map[Kind]Handler),
		//nolint:gosec // G404: jitter is non-cryptographic by design.
		rng: mrand.New(mrand.NewSource(time.Now().UnixNano())),
	}
}
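
// A minimal usage sketch; KindExample and doWork are hypothetical, not
// part of this package (Handler is func(context.Context, json.RawMessage)
// error, as safeRun implies):
//
//	pool := NewPool(db, PoolConfig{Workers: 8})
//	pool.Register(KindExample, func(ctx context.Context, payload json.RawMessage) error {
//		var req struct{ RepoID int64 }
//		if err := json.Unmarshal(payload, &req); err != nil {
//			return fmt.Errorf("%w: bad payload: %v", ErrPoison, err) // malformed payload: never retryable
//		}
//		return doWork(ctx, req.RepoID)
//	})
//	return pool.Run(ctx)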

// Register associates a handler with a kind. Re-registering replaces
// the previous handler. Registration is goroutine-safe so test harnesses
// can swap handlers between runs.
func (p *Pool) Register(kind Kind, h Handler) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.handlers[kind] = h
}

// Run blocks until ctx is cancelled. It spawns cfg.Workers worker
// goroutines plus one LISTEN goroutine that fans out wake-ups, waits
// for all of them to exit, and then returns nil.
func (p *Pool) Run(ctx context.Context) error {
	p.cfg.Logger.InfoContext(ctx, "worker: starting",
		"workers", p.cfg.Workers,
		"instance_id", p.cfg.InstanceID,
		"kinds", p.kindList())

	wake := make(chan struct{}, p.cfg.Workers)
	var wg sync.WaitGroup

	// LISTEN goroutine. Holds a dedicated conn for the lifetime of Run.
	wg.Add(1)
	go func() {
		defer wg.Done()
		p.listenLoop(ctx, wake)
	}()

	for i := 0; i < p.cfg.Workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			p.workerLoop(ctx, id, wake)
		}(i)
	}

	wg.Wait()
	p.cfg.Logger.InfoContext(ctx, "worker: stopped")
	return nil
}
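
// Run returns only when ctx is cancelled, so callers typically tie it to
// process lifetime. A sketch of the usual wiring (caller-side, not part
// of this package):
//
//	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
//	defer stop()
//	_ = pool.Run(ctx) // returns nil once the workers have drained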

func (p *Pool) kindList() []string {
	p.mu.Lock()
	defer p.mu.Unlock()
	out := make([]string, 0, len(p.handlers))
	for k := range p.handlers {
		out = append(out, string(k))
	}
	return out
}

// listenLoop maintains a LISTEN on NotifyChannel. On each NOTIFY it fans
// out to wake; on errors that require a reconnect it sleeps briefly and
// retries.
func (p *Pool) listenLoop(ctx context.Context, wake chan<- struct{}) {
	for {
		if err := ctx.Err(); err != nil {
			return
		}
		if err := p.listenOnce(ctx, wake); err != nil && !errors.Is(err, context.Canceled) {
			p.cfg.Logger.WarnContext(ctx, "worker: listen restart", "error", err)
			select {
			case <-ctx.Done():
				return
			case <-time.After(2 * time.Second):
			}
		}
	}
}

func (p *Pool) listenOnce(ctx context.Context, wake chan<- struct{}) error {
	conn, err := p.db.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquire: %w", err)
	}
	defer conn.Release()

	if _, err := conn.Exec(ctx, "LISTEN "+NotifyChannel); err != nil {
		return fmt.Errorf("LISTEN: %w", err)
	}
	for {
		_, err := conn.Conn().WaitForNotification(ctx)
		if err != nil {
			return err
		}
		// Fan out to as many workers as are idle. Non-blocking sends so
		// we don't stall on a saturated pool.
		for i := 0; i < p.cfg.Workers; i++ {
			select {
			case wake <- struct{}{}:
			default:
			}
		}
	}
}
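
// Producers are expected to NOTIFY on NotifyChannel after enqueueing a
// job so workers wake without waiting for IdlePoll. A minimal sketch of
// that side (the real enqueue path lives elsewhere and may differ):
//
//	if _, err := db.Exec(ctx, "SELECT pg_notify($1, '')", NotifyChannel); err != nil {
//		return fmt.Errorf("notify: %w", err)
//	}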

func (p *Pool) workerLoop(ctx context.Context, id int, wake <-chan struct{}) {
	logger := p.cfg.Logger.With("worker_id", id)
	ticker := time.NewTicker(p.cfg.IdlePoll)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-wake:
		case <-ticker.C:
		}
		// Drain: try every registered kind; if any kind returned a job,
		// loop back immediately without waiting on wake/tick.
		for {
			any, err := p.tryClaimAndRun(ctx, logger)
			if err != nil {
				logger.WarnContext(ctx, "worker: claim cycle error", "error", err)
				break
			}
			if !any {
				break
			}
		}
	}
}

// tryClaimAndRun walks every registered kind and attempts to claim one
// job of each. Returns true if any kind produced work this pass.
func (p *Pool) tryClaimAndRun(ctx context.Context, logger *slog.Logger) (bool, error) {
	p.mu.Lock()
	kinds := make([]Kind, 0, len(p.handlers))
	for k := range p.handlers {
		kinds = append(kinds, k)
	}
	p.mu.Unlock()

	any := false
	for _, kind := range kinds {
		if ctx.Err() != nil {
			return any, ctx.Err()
		}
		ran, err := p.runOne(ctx, kind, logger)
		if err != nil {
			return any, err
		}
		if ran {
			any = true
		}
	}
	return any, nil
}

// runOne claims one job of the given kind, runs it, and records the
// outcome. Returns ran=true when a job was claimed (regardless of
// success), false when the queue had nothing for this kind.
func (p *Pool) runOne(ctx context.Context, kind Kind, logger *slog.Logger) (bool, error) {
	job, err := p.q.ClaimJob(ctx, p.db, workerdb.ClaimJobParams{
		Kind:     string(kind),
		LockedBy: pgtype.Text{String: p.cfg.InstanceID, Valid: true},
	})
	if errors.Is(err, pgx.ErrNoRows) {
		return false, nil
	}
	if err != nil {
		return false, fmt.Errorf("claim %s: %w", kind, err)
	}

	p.mu.Lock()
	h, ok := p.handlers[Kind(job.Kind)]
	p.mu.Unlock()
	if !ok {
		// The registered handler vanished between claim and dispatch;
		// fail the job rather than block the queue. This should never
		// happen in practice, since Register runs once at boot.
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: "no handler registered", Valid: true},
		})
		return true, nil
	}

	jobCtx, cancel := context.WithTimeout(ctx, p.cfg.JobTimeout)
	start := time.Now()
	metrics.WorkerInFlight.WithLabelValues(job.Kind).Inc()
	runErr := safeRun(jobCtx, h, job.Payload)
	metrics.WorkerInFlight.WithLabelValues(job.Kind).Dec()
	cancel()
	metrics.WorkerJobDurationSeconds.WithLabelValues(job.Kind).Observe(time.Since(start).Seconds())

	logger.InfoContext(
		ctx, "worker: dispatched",
		"job_id", job.ID,
		"kind", job.Kind,
		"attempt", job.Attempts,
		"duration_ms", time.Since(start).Milliseconds(),
		"ok", runErr == nil,
	)

	if runErr == nil {
		if err := p.q.MarkJobCompleted(ctx, p.db, job.ID); err != nil {
			logger.ErrorContext(ctx, "worker: mark completed", "job_id", job.ID, "error", err)
		}
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "ok").Inc()
		return true, nil
	}

	// Failure path. Poison errors skip retry.
	if errors.Is(runErr, ErrPoison) {
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		})
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "poison").Inc()
		return true, nil
	}

	if int(job.Attempts) >= int(job.MaxAttempts) {
		_ = p.q.MarkJobFailed(ctx, p.db, workerdb.MarkJobFailedParams{
			ID:        job.ID,
			LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		})
		metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "failed").Inc()
		return true, nil
	}

	p.mu.Lock()
	delay := Backoff(int(job.Attempts), p.rng.Float64)
	p.mu.Unlock()
	_ = p.q.RescheduleJob(ctx, p.db, workerdb.RescheduleJobParams{
		ID:        job.ID,
		LastError: pgtype.Text{String: runErr.Error(), Valid: true},
		RunAt:     pgtype.Timestamptz{Time: time.Now().Add(delay), Valid: true},
	})
	metrics.WorkerJobsProcessedTotal.WithLabelValues(job.Kind, "retry").Inc()
	return true, nil
}
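
// Backoff is defined elsewhere in this package; the call above implies
// Backoff(attempt int, randFloat func() float64) time.Duration. A typical
// jittered-exponential shape, as a sketch only (the real constants and
// cap may differ):
//
//	func Backoff(attempt int, randFloat func() float64) time.Duration {
//		base := time.Second << uint(attempt) // 1s, 2s, 4s, ...
//		if base > 10*time.Minute {
//			base = 10 * time.Minute
//		}
//		return base + time.Duration(randFloat()*float64(base)/2) // up to +50% jitter
//	}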

// safeRun wraps the handler in a recover so a panicking handler doesn't
// take the worker goroutine with it; the job is rescheduled like any
// other failure.
func safeRun(ctx context.Context, h Handler, payload json.RawMessage) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("worker: handler panic: %v", r)
		}
	}()
	return h(ctx, payload)
}
| 319 |