1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package main
4
5 import (
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "os/signal"
11 "path/filepath"
12 "strconv"
13 "syscall"
14 "time"
15
16 "github.com/spf13/cobra"
17
18 "github.com/tenseleyFlow/shithub/internal/auth/audit"
19 "github.com/tenseleyFlow/shithub/internal/infra/config"
20 "github.com/tenseleyFlow/shithub/internal/infra/db"
21 "github.com/tenseleyFlow/shithub/internal/infra/storage"
22 "github.com/tenseleyFlow/shithub/internal/worker"
23 "github.com/tenseleyFlow/shithub/internal/worker/jobs"
24 )
25
26 // auditRecorder returns the shared audit recorder. Kept as a function
27 // rather than a package-level value so future tests / non-default
28 // recorders can substitute via dependency injection.
29 func auditRecorder() *audit.Recorder { return audit.NewRecorder() }
30
31 // workerCmd boots a long-running worker pool. SIGINT/SIGTERM trigger
32 // graceful shutdown: the LISTEN goroutine drops, claim attempts stop,
33 // in-flight jobs are given a deadline to finish, then the binary exits.
34 var workerCmd = &cobra.Command{
35 Use: "worker",
36 Short: "Run background workers (push processing, size recalc, purge)",
37 RunE: func(cmd *cobra.Command, _ []string) error {
38 workersFlag, _ := cmd.Flags().GetInt("workers")
39
40 cfg, err := config.Load(nil)
41 if err != nil {
42 return fmt.Errorf("config: %w", err)
43 }
44 if cfg.DB.URL == "" {
45 return errors.New("worker: SHITHUB_DATABASE_URL unset")
46 }
47 root, err := filepath.Abs(cfg.Storage.ReposRoot)
48 if err != nil {
49 return fmt.Errorf("repos_root: %w", err)
50 }
51 rfs, err := storage.NewRepoFS(root)
52 if err != nil {
53 return fmt.Errorf("repo fs: %w", err)
54 }
55
56 ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM)
57 defer stop()
58
59 // Worker count: flag overrides env override default.
60 count := workersFlag
61 if count <= 0 {
62 if v, _ := strconv.Atoi(os.Getenv("SHITHUB_WORKERS")); v > 0 {
63 count = v
64 }
65 }
66
67 pool, err := db.Open(ctx, db.Config{
68 URL: cfg.DB.URL, MaxConns: int32(count) + 2, MinConns: 1,
69 })
70 if err != nil {
71 return fmt.Errorf("db: %w", err)
72 }
73 defer pool.Close()
74
75 logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
76
77 p := worker.NewPool(pool, worker.PoolConfig{
78 Workers: count,
79 IdlePoll: 5 * time.Second,
80 JobTimeout: 5 * time.Minute,
81 Logger: logger,
82 })
83 p.Register(worker.KindPushProcess, jobs.PushProcess(jobs.PushProcessDeps{
84 Pool: pool, RepoFS: rfs, Logger: logger,
85 }))
86 p.Register(worker.KindRepoSizeRecalc, jobs.RepoSizeRecalc(jobs.RepoSizeRecalcDeps{
87 Pool: pool, RepoFS: rfs, Logger: logger,
88 }))
89 p.Register(worker.KindJobsPurge, jobs.JobsPurge(jobs.JobsPurgeDeps{
90 Pool: pool, Logger: logger,
91 }))
92 p.Register(worker.KindLifecycleSweep, jobs.LifecycleSweep(jobs.LifecycleSweepDeps{
93 Pool: pool, RepoFS: rfs, Audit: auditRecorder(), Logger: logger,
94 }))
95 prDeps := jobs.PRJobsDeps{Pool: pool, RepoFS: rfs, Logger: logger}
96 p.Register(worker.KindPRSynchronize, jobs.PRSynchronize(prDeps))
97 p.Register(worker.KindPRMergeability, jobs.PRMergeability(prDeps))
98
99 shithubdPath, _ := shithubdBinaryPath()
100 p.Register(worker.KindRepoForkClone, jobs.RepoForkClone(jobs.ForkCloneDeps{
101 Pool: pool, RepoFS: rfs, Logger: logger, ShithubdPath: shithubdPath,
102 }))
103
104 p.Register(worker.KindRepoIndexCode, jobs.RepoIndexCode(jobs.IndexCodeDeps{
105 Pool: pool, RepoFS: rfs, Logger: logger,
106 }))
107 p.Register(worker.KindRepoIndexReconcile, jobs.RepoIndexReconcile(jobs.IndexReconcileDeps{
108 Pool: pool, Logger: logger,
109 }))
110
111 return p.Run(ctx)
112 },
113 }
114
115 func init() {
116 workerCmd.Flags().Int("workers", 0, "Number of worker goroutines (default 4)")
117 rootCmd.AddCommand(workerCmd)
118 }
119