Go · 11738 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package main
4
5 import (
6 "bufio"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "os"
13 "path/filepath"
14 "strconv"
15 "strings"
16 "time"
17
18 "github.com/jackc/pgx/v5/pgtype"
19 "github.com/jackc/pgx/v5/pgxpool"
20 "github.com/spf13/cobra"
21
22 "github.com/tenseleyFlow/shithub/internal/auth/policy"
23 "github.com/tenseleyFlow/shithub/internal/infra/config"
24 "github.com/tenseleyFlow/shithub/internal/infra/db"
25 "github.com/tenseleyFlow/shithub/internal/infra/storage"
26 "github.com/tenseleyFlow/shithub/internal/repos/protection"
27 reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
28 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
29 "github.com/tenseleyFlow/shithub/internal/worker"
30 workerdb "github.com/tenseleyFlow/shithub/internal/worker/sqlc"
31 )
32
33 // hookCmd is the umbrella for `shithubd hook <name>`. Each named hook
34 // is a leaf subcommand; the symlink shim installed by hooks.Install
35 // invokes one of them. Hidden because no human runs these directly.
36 var hookCmd = &cobra.Command{
37 Use: "hook",
38 Short: "Git hook entrypoints (post-receive, pre-receive)",
39 Hidden: true,
40 }
41
42 // hookPreReceiveCmd implements the minimum-gates pre-receive hook
43 // described in S14. Full branch-protection gates land in S20.
44 //
45 // Stdin lines: "<old_sha> <new_sha> <ref>".
46 //
47 // Exit codes:
48 // - 0: accept the push.
49 // - 1: reject; git aborts and prints whatever we wrote to stderr.
50 //
51 // Latency budget: under 100ms for the common case (no archive/suspension).
52 // We re-check user/repo state from the DB to avoid trusting potentially
53 // stale env vars from long-lived SSH sessions.
54 var hookPreReceiveCmd = &cobra.Command{
55 Use: "pre-receive",
56 Short: "Hook: pre-receive — minimum-gates accept/reject",
57 Hidden: true,
58 RunE: func(cmd *cobra.Command, _ []string) error {
59 ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
60 defer cancel()
61
62 hook, err := loadHookCtx(ctx)
63 if err != nil {
64 fmt.Fprintln(cmd.ErrOrStderr(), friendlyHookErr(err))
65 return err
66 }
67 defer hook.pool.Close()
68
69 refs, err := readRefLines(cmd.InOrStdin())
70 if err != nil {
71 fmt.Fprintln(cmd.ErrOrStderr(), "shithub: failed to read ref updates")
72 return err
73 }
74
75 if err := preReceiveCheck(ctx, hook); err != nil {
76 fmt.Fprintln(cmd.ErrOrStderr(), friendlyHookErr(err))
77 return err
78 }
79
80 // Branch-protection enforcement (S20). Per-ref check against
81 // the rule set; longest-pattern match wins. A single rejected
82 // ref aborts the entire push (git's standard non-atomic
83 // per-ref accept/reject still applies — pre-receive nonzero
84 // rejects all refs in the push, which matches our intent for
85 // "any rule says no, the whole push stops").
86 gitDir, err := repoGitDir(ctx, hook)
87 if err != nil {
88 fmt.Fprintln(cmd.ErrOrStderr(), friendlyHookErr(err))
89 return err
90 }
91 for _, rf := range refs {
92 d, perr := protection.Enforce(ctx, hook.pool, gitDir, hook.repoID, protection.Update{
93 OldSHA: rf.before, NewSHA: rf.after, Ref: rf.ref, Pusher: hook.userID,
94 })
95 if perr != nil {
96 fmt.Fprintln(cmd.ErrOrStderr(), "shithub: protection check failed (transient); please retry")
97 return perr
98 }
99 if !d.Allow {
100 fmt.Fprintln(cmd.ErrOrStderr(), protection.FriendlyMessage(d))
101 return errors.New("protection denied")
102 }
103 }
104 return nil
105 },
106 }
107
108 // hookPostReceiveCmd records each pushed ref as a push_events row,
109 // enqueues a push:process job per ref, and NOTIFYs idle workers.
110 // Latency budget: under 100ms for typical small pushes; we keep the
111 // hook to INSERT + NOTIFY + exit. No HTTP calls, no derivation work.
112 var hookPostReceiveCmd = &cobra.Command{
113 Use: "post-receive",
114 Short: "Hook: post-receive — enqueue async processing",
115 Hidden: true,
116 RunE: func(cmd *cobra.Command, _ []string) error {
117 ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
118 defer cancel()
119
120 hook, err := loadHookCtx(ctx)
121 if err != nil {
122 // post-receive is non-fatal: the push has already landed. We
123 // log to stderr (the user's git client sees it) but exit 0
124 // so the push isn't reported as failed.
125 fmt.Fprintln(cmd.ErrOrStderr(), "shithub: warning: post-receive enqueue skipped:", err)
126 return nil
127 }
128 defer hook.pool.Close()
129
130 refs, err := readRefLines(cmd.InOrStdin())
131 if err != nil || len(refs) == 0 {
132 return nil
133 }
134
135 if err := postReceiveEnqueue(ctx, hook, refs); err != nil {
136 fmt.Fprintln(cmd.ErrOrStderr(), "shithub: warning: post-receive enqueue:", err)
137 }
138 return nil
139 },
140 }
141
142 // hookCtx bundles the deps each hook subcommand needs. Loaded once per
143 // invocation; closed by the caller via defer.
144 type hookCtx struct {
145 cfg config.Config
146 pool *pgxpool.Pool
147 logger *slog.Logger
148
149 userID int64
150 username string
151 repoID int64
152 repoFull string
153 protocol string
154 remoteIP string
155 requestID string
156 }
157
158 func loadHookCtx(ctx context.Context) (*hookCtx, error) {
159 cfg, err := config.Load(nil)
160 if err != nil {
161 return nil, fmt.Errorf("config: %w", err)
162 }
163 if cfg.DB.URL == "" {
164 return nil, errors.New("DB URL not set")
165 }
166
167 pool, err := db.Open(ctx, db.Config{
168 URL: cfg.DB.URL, MaxConns: 2, MinConns: 0,
169 ConnectTimeout: 1500 * time.Millisecond,
170 })
171 if err != nil {
172 return nil, fmt.Errorf("db: %w", err)
173 }
174
175 logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
176
177 uid, _ := strconv.ParseInt(os.Getenv("SHITHUB_USER_ID"), 10, 64)
178 rid, _ := strconv.ParseInt(os.Getenv("SHITHUB_REPO_ID"), 10, 64)
179 return &hookCtx{
180 cfg: cfg,
181 pool: pool,
182 logger: logger,
183 userID: uid,
184 username: os.Getenv("SHITHUB_USERNAME"),
185 repoID: rid,
186 repoFull: os.Getenv("SHITHUB_REPO_FULL_NAME"),
187 protocol: os.Getenv("SHITHUB_PROTOCOL"),
188 remoteIP: os.Getenv("SHITHUB_REMOTE_IP"),
189 requestID: os.Getenv("SHITHUB_REQUEST_ID"),
190 }, nil
191 }
192
// errHookGate is the typed error pre-receive returns for each rejection
// reason. It is a comparable value type, so errors.Is matches the
// sentinels below by equality, including through %w wrapping.
// friendlyHookErr maps these back to user-facing messages.
type errHookGate struct{ kind string }

// Error implements error with a stable "shithub-hook: <kind>" text.
func (e errHookGate) Error() string {
	const prefix = "shithub-hook: "
	return prefix + e.kind
}

// Rejection sentinels used by the hook checks.
var (
	errHookSuspended  = errHookGate{kind: "user suspended"}    // pusher's account is suspended
	errHookArchived   = errHookGate{kind: "repo archived"}     // target repo is archived
	errHookDeleted    = errHookGate{kind: "repo deleted"}      // target repo is deleted
	errHookMissing    = errHookGate{kind: "missing context"}   // hook env lacks user/repo IDs
	errHookPermDenied = errHookGate{kind: "permission denied"} // any other policy denial
)
206
207 // repoGitDir resolves the bare-repo on-disk path for the hook's repo.
208 // Used by the protection enforcer's IsAncestor check. Hook env carries
209 // SHITHUB_REPO_FULL_NAME ("owner/name") so we don't need a DB hit.
210 func repoGitDir(ctx context.Context, h *hookCtx) (string, error) {
211 owner, name, ok := strings.Cut(h.repoFull, "/")
212 if !ok {
213 return "", fmt.Errorf("repoGitDir: bad repo full name %q", h.repoFull)
214 }
215 root, err := filepath.Abs(h.cfg.Storage.ReposRoot)
216 if err != nil {
217 return "", fmt.Errorf("repoGitDir: abs: %w", err)
218 }
219 rfs, err := storage.NewRepoFS(root)
220 if err != nil {
221 return "", fmt.Errorf("repoGitDir: fs: %w", err)
222 }
223 return rfs.RepoPath(owner, name)
224 }
225
226 func friendlyHookErr(err error) string {
227 switch {
228 case errors.Is(err, errHookSuspended):
229 return "shithub: your account is suspended; pushes are disabled."
230 case errors.Is(err, errHookArchived):
231 return "shithub: this repository is archived; pushes are disabled."
232 case errors.Is(err, errHookDeleted):
233 return "shithub: this repository has been deleted."
234 case errors.Is(err, errHookPermDenied):
235 return "shithub: you do not have write access to this repository."
236 case errors.Is(err, errHookMissing):
237 return "shithub: server error: hook context missing. Contact the operator."
238 default:
239 return "shithub: server error: " + err.Error()
240 }
241 }
242
243 func preReceiveCheck(ctx context.Context, h *hookCtx) error {
244 if h.userID == 0 || h.repoID == 0 {
245 return errHookMissing
246 }
247 uq := usersdb.New()
248 user, err := uq.GetUserByID(ctx, h.pool, h.userID)
249 if err != nil {
250 return fmt.Errorf("user lookup: %w", err)
251 }
252
253 rq := reposdb.New()
254 repo, err := rq.GetRepoByID(ctx, h.pool, h.repoID)
255 if err != nil {
256 return fmt.Errorf("repo lookup: %w", err)
257 }
258
259 actor := policy.UserActor(user.ID, user.Username, user.SuspendedAt.Valid, false)
260 repoRef := policy.NewRepoRefFromRepo(repo)
261 decision := policy.Can(ctx, policy.Deps{Pool: h.pool}, actor, policy.ActionRepoWrite, repoRef)
262 if decision.Allow {
263 return nil
264 }
265 switch decision.Code {
266 case policy.DenyRepoDeleted:
267 return errHookDeleted
268 case policy.DenyActorSuspended:
269 return errHookSuspended
270 case policy.DenyArchived:
271 return errHookArchived
272 default:
273 return errHookPermDenied
274 }
275 }
276
277 func postReceiveEnqueue(ctx context.Context, h *hookCtx, refs []refUpdate) error {
278 if h.repoID == 0 {
279 return errHookMissing
280 }
281
282 tx, err := h.pool.Begin(ctx)
283 if err != nil {
284 return fmt.Errorf("begin: %w", err)
285 }
286 committed := false
287 defer func() {
288 if !committed {
289 _ = tx.Rollback(ctx)
290 }
291 }()
292
293 wq := workerdb.New()
294 protocol := h.protocol
295 if protocol == "" {
296 protocol = "ssh" // safe fallback when env is missing
297 }
298 for _, r := range refs {
299 event, err := wq.InsertPushEvent(ctx, tx, workerdb.InsertPushEventParams{
300 RepoID: h.repoID,
301 BeforeSha: r.before,
302 AfterSha: r.after,
303 Ref: r.ref,
304 Protocol: protocol,
305 PusherUserID: pgtype.Int8{Int64: h.userID, Valid: h.userID != 0},
306 RequestID: pgtype.Text{String: h.requestID, Valid: h.requestID != ""},
307 })
308 if err != nil {
309 return fmt.Errorf("insert push_event: %w", err)
310 }
311 if _, err := worker.Enqueue(ctx, tx, worker.KindPushProcess,
312 map[string]any{"push_event_id": event.ID},
313 worker.EnqueueOptions{}); err != nil {
314 return fmt.Errorf("enqueue push:process: %w", err)
315 }
316 }
317 if err := worker.Notify(ctx, tx); err != nil {
318 // Notify failure inside tx is non-fatal — workers also poll.
319 h.logger.WarnContext(ctx, "post-receive: NOTIFY failed", "error", err)
320 }
321 if err := tx.Commit(ctx); err != nil {
322 return fmt.Errorf("commit: %w", err)
323 }
324 committed = true
325 return nil
326 }
327
// refUpdate is one parsed line of git's hook stdin:
// "<old_sha> <new_sha> <ref>".
type refUpdate struct {
	before, after, ref string
}

// readRefLines parses hook stdin into refUpdates, one per line.
// Whitespace-only lines are skipped. So is any line that does not
// split into exactly three whitespace-separated fields.
// Lines may be up to 1 MiB long. A longer line, or any read error,
// is returned as an error together with the updates parsed so far.
//
// NOTE(review): skipping malformed lines means pre-receive would not
// check them. That is safe only as long as git never emits a malformed
// line (ref names cannot contain spaces).
func readRefLines(r io.Reader) ([]refUpdate, error) {
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 64<<10), 1<<20)

	var updates []refUpdate
	for scanner.Scan() {
		// Fields trims surrounding whitespace itself, and a blank line
		// yields zero fields, so both cases fall out of the length check.
		fields := strings.Fields(scanner.Text())
		if len(fields) == 3 {
			updates = append(updates, refUpdate{before: fields[0], after: fields[1], ref: fields[2]})
		}
	}
	if err := scanner.Err(); err != nil {
		return updates, err
	}
	return updates, nil
}
350
351 // hooksReinstallCmd reinstalls hook symlinks on every active repo, used
352 // after a binary path change in production deploys. --repo runs against
353 // a single owner/name; --all walks every repo via the DB.
354 var hooksReinstallCmd = &cobra.Command{
355 Use: "reinstall",
356 Short: "Reinstall hook symlinks on existing repos",
357 RunE: func(cmd *cobra.Command, _ []string) error {
358 all, _ := cmd.Flags().GetBool("all")
359 repo, _ := cmd.Flags().GetString("repo")
360 if !all && repo == "" {
361 return errors.New("hooks reinstall: pass --all or --repo owner/name")
362 }
363 return runHooksReinstall(cmd.Context(), all, repo, cmd.OutOrStdout())
364 },
365 }
366
367 // hooksParentCmd is the umbrella so the operator command reads as
368 // `shithubd hooks reinstall ...`.
369 var hooksParentCmd = &cobra.Command{
370 Use: "hooks",
371 Short: "Operator commands for git hook installation",
372 }
373
374 func init() {
375 hookCmd.AddCommand(hookPreReceiveCmd)
376 hookCmd.AddCommand(hookPostReceiveCmd)
377 hooksReinstallCmd.Flags().Bool("all", false, "Reinstall on every active repo")
378 hooksReinstallCmd.Flags().String("repo", "", "Reinstall on owner/name only")
379 hooksParentCmd.AddCommand(hooksReinstallCmd)
380
381 rootCmd.AddCommand(hookCmd)
382 rootCmd.AddCommand(hooksParentCmd)
383 }
384