Go · 16096 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package pulls owns the pull-request orchestrator. PRs reuse the S21
4 // `issues` row for title/body/state/timeline; this package owns the
5 // PR-specific surface — opening, synchronizing, mergeability detection,
6 // merge execution.
7 //
8 // Entry points are:
9 //
10 // Create — opens a PR (creates the issue row + the pull_requests row)
11 // Synchronize — refreshes commit + file lists + emits a synchronized event
12 // Mergeability — recomputes mergeable / mergeable_state via merge-tree
13 // Merge — performs the requested merge strategy in a temp worktree
14 // Edit — title/body
15 // SetState — close / reopen
16 // SetReady — draft → ready
17 package pulls
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "log/slog"
24 "strings"
25
26 "github.com/jackc/pgx/v5"
27 "github.com/jackc/pgx/v5/pgtype"
28 "github.com/jackc/pgx/v5/pgxpool"
29
30 "github.com/tenseleyFlow/shithub/internal/auth/audit"
31 "github.com/tenseleyFlow/shithub/internal/checks"
32 "github.com/tenseleyFlow/shithub/internal/issues"
33 issuesdb "github.com/tenseleyFlow/shithub/internal/issues/sqlc"
34 mdrender "github.com/tenseleyFlow/shithub/internal/markdown"
35 "github.com/tenseleyFlow/shithub/internal/pulls/review"
36 pullsdb "github.com/tenseleyFlow/shithub/internal/pulls/sqlc"
37 repogit "github.com/tenseleyFlow/shithub/internal/repos/git"
38 "github.com/tenseleyFlow/shithub/internal/repos/protection"
39 reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
40 )
41
42 // Deps wires this package against the rest of the runtime.
43 type Deps struct {
44 Pool *pgxpool.Pool
45 Logger *slog.Logger
46 // Audit is optional; when non-nil, Merge and SetState write audit
47 // rows. Mirrors the issues orchestrator's contract so PR-side
48 // state changes are equally traceable (S00-S25 audit, M).
49 Audit *audit.Recorder
50 }
51
52 // Errors surfaced to handlers.
53 var (
54 ErrSameBranch = errors.New("pulls: base and head must differ")
55 ErrBaseNotFound = errors.New("pulls: base ref not found")
56 ErrHeadNotFound = errors.New("pulls: head ref not found")
57 ErrNoCommitsToMerge = errors.New("pulls: head has no commits ahead of base")
58 ErrAlreadyMerged = errors.New("pulls: already merged")
59 ErrAlreadyClosed = errors.New("pulls: already closed")
60 ErrMergeBlocked = errors.New("pulls: merge blocked (mergeable_state != clean)")
61 ErrMergeMethodOff = errors.New("pulls: requested merge method is disabled on this repo")
62 ErrConcurrentMerge = errors.New("pulls: PR is being merged by another request")
63 ErrPRNotFound = errors.New("pulls: PR not found")
64 )
65
66 // CreateParams describes a new-PR request.
67 type CreateParams struct {
68 RepoID int64
69 AuthorUserID int64
70 Title string
71 Body string
72 BaseRef string
73 HeadRef string
74 Draft bool
75 GitDir string // resolved from RepoFS by the caller
76 }
77
78 // CreateResult bundles the issue row + the PR row (post-snapshot).
79 type CreateResult struct {
80 Issue issuesdb.Issue
81 PullRequest pullsdb.PullRequest
82 }
83
84 // Create opens a PR. Validates that base/head are distinct and resolve
85 // in the on-disk repo, snapshots their OIDs, then creates the issues
86 // row + pull_requests row in one tx. Mergeability is `unknown` until
87 // the worker job ticks.
88 func Create(ctx context.Context, deps Deps, p CreateParams) (CreateResult, error) {
89 base := strings.TrimSpace(p.BaseRef)
90 head := strings.TrimSpace(p.HeadRef)
91 if base == "" || head == "" || base == head {
92 return CreateResult{}, ErrSameBranch
93 }
94 baseOID, err := repogit.ResolveRefOID(ctx, p.GitDir, base)
95 if err != nil {
96 if errors.Is(err, repogit.ErrRefNotFound) {
97 return CreateResult{}, ErrBaseNotFound
98 }
99 return CreateResult{}, fmt.Errorf("resolve base: %w", err)
100 }
101 headOID, err := repogit.ResolveRefOID(ctx, p.GitDir, head)
102 if err != nil {
103 if errors.Is(err, repogit.ErrRefNotFound) {
104 return CreateResult{}, ErrHeadNotFound
105 }
106 return CreateResult{}, fmt.Errorf("resolve head: %w", err)
107 }
108 if baseOID == headOID {
109 return CreateResult{}, ErrNoCommitsToMerge
110 }
111
112 // Open the issues row first via the issues orchestrator so we get
113 // the per-repo number allocation, body markdown render, and
114 // reference indexing for free.
115 issueRow, err := issues.Create(ctx, issues.Deps{Pool: deps.Pool, Logger: deps.Logger}, issues.CreateParams{
116 RepoID: p.RepoID,
117 AuthorUserID: p.AuthorUserID,
118 Title: p.Title,
119 Body: p.Body,
120 Kind: "pr",
121 })
122 if err != nil {
123 return CreateResult{}, err
124 }
125
126 prRow, err := pullsdb.New().CreatePullRequest(ctx, deps.Pool, pullsdb.CreatePullRequestParams{
127 IssueID: issueRow.ID,
128 BaseRef: base,
129 HeadRef: head,
130 HeadRepoID: p.RepoID,
131 BaseOid: baseOID,
132 HeadOid: headOID,
133 Draft: p.Draft,
134 })
135 if err != nil {
136 return CreateResult{}, fmt.Errorf("create pull_request: %w", err)
137 }
138
139 // Best-effort initial synchronize so the PR view has commits + files
140 // even before the worker queue runs. Failures here don't fail the
141 // open — the worker will retry on the next tick.
142 if err := refreshCommitsAndFiles(ctx, deps, p.GitDir, prRow.IssueID, baseOID, headOID); err != nil {
143 if deps.Logger != nil {
144 deps.Logger.WarnContext(ctx, "pulls: initial sync", "error", err, "pr_id", prRow.IssueID)
145 }
146 }
147
148 return CreateResult{Issue: issueRow, PullRequest: prRow}, nil
149 }
150
151 // refreshCommitsAndFiles is shared by Create + Synchronize. Truncates +
152 // re-fills `pull_request_commits` and `pull_request_files`.
153 func refreshCommitsAndFiles(ctx context.Context, deps Deps, gitDir string, prID int64, baseOID, headOID string) error {
154 commits, err := repogit.CommitsBetweenDetail(ctx, gitDir, baseOID, headOID, 250)
155 if err != nil {
156 return fmt.Errorf("commits: %w", err)
157 }
158 files, err := repogit.FilesChangedBetween(ctx, gitDir, baseOID, headOID)
159 if err != nil {
160 return fmt.Errorf("files: %w", err)
161 }
162 tx, err := deps.Pool.Begin(ctx)
163 if err != nil {
164 return err
165 }
166 committed := false
167 defer func() {
168 if !committed {
169 _ = tx.Rollback(ctx)
170 }
171 }()
172 q := pullsdb.New()
173 if err := q.ClearPullRequestCommits(ctx, tx, prID); err != nil {
174 return err
175 }
176 if err := q.ClearPullRequestFiles(ctx, tx, prID); err != nil {
177 return err
178 }
179 for i, c := range commits {
180 var ats, cts pgtype.Timestamptz
181 if !c.AuthorWhen.IsZero() {
182 ats = pgtype.Timestamptz{Time: c.AuthorWhen, Valid: true}
183 }
184 if !c.CommitterWhen.IsZero() {
185 cts = pgtype.Timestamptz{Time: c.CommitterWhen, Valid: true}
186 }
187 if err := q.InsertPullRequestCommit(ctx, tx, pullsdb.InsertPullRequestCommitParams{
188 PrID: prID,
189 Sha: c.OID,
190 Position: int32(i),
191 AuthorName: c.AuthorName,
192 AuthorEmail: c.AuthorEmail,
193 CommitterName: c.CommitterName,
194 CommitterEmail: c.CommitterEmail,
195 Subject: c.Subject,
196 Body: c.Body,
197 AuthoredAt: ats,
198 CommittedAt: cts,
199 }); err != nil {
200 return err
201 }
202 }
203 for _, f := range files {
204 oldPath := pgtype.Text{}
205 if f.OldPath != "" {
206 oldPath = pgtype.Text{String: f.OldPath, Valid: true}
207 }
208 if err := q.InsertPullRequestFile(ctx, tx, pullsdb.InsertPullRequestFileParams{
209 PrID: prID,
210 Path: f.Path,
211 Status: pullsdb.PrFileStatus(f.Status),
212 OldPath: oldPath,
213 Additions: int32(f.Additions),
214 Deletions: int32(f.Deletions),
215 Changes: int32(f.Additions + f.Deletions),
216 }); err != nil {
217 return err
218 }
219 }
220 if err := q.SetPullRequestSnapshot(ctx, tx, pullsdb.SetPullRequestSnapshotParams{
221 IssueID: prID, BaseOid: baseOID, HeadOid: headOID,
222 }); err != nil {
223 return err
224 }
225 if err := tx.Commit(ctx); err != nil {
226 return err
227 }
228 committed = true
229 return nil
230 }
231
232 // Synchronize re-snapshots the PR's base/head OIDs, refreshes the
233 // commits + files lists, and emits a `synchronized` event into the
234 // issue timeline. Called from the pr:synchronize worker job after
235 // any push to the head ref.
236 func Synchronize(ctx context.Context, deps Deps, gitDir string, prID int64) error {
237 q := pullsdb.New()
238 pr, err := q.GetPullRequestByIssueID(ctx, deps.Pool, prID)
239 if err != nil {
240 if errors.Is(err, pgx.ErrNoRows) {
241 return ErrPRNotFound
242 }
243 return err
244 }
245 baseOID, err := repogit.ResolveRefOID(ctx, gitDir, pr.BaseRef)
246 if err != nil {
247 return fmt.Errorf("resolve base: %w", err)
248 }
249 headOID, err := repogit.ResolveRefOID(ctx, gitDir, pr.HeadRef)
250 if err != nil {
251 return fmt.Errorf("resolve head: %w", err)
252 }
253 if err := refreshCommitsAndFiles(ctx, deps, gitDir, prID, baseOID, headOID); err != nil {
254 return err
255 }
256 // Re-anchor review comments against the new snapshot. Comments
257 // whose original line still exists keep their thread; the rest
258 // outdate (current_position=NULL) and surface in the "Show
259 // outdated" toggle of the Files tab.
260 if err := review.RemapAllForPR(ctx, review.Deps{Pool: deps.Pool, Logger: deps.Logger}, gitDir, prID, baseOID, headOID); err != nil {
261 // Best-effort: a position-map miss shouldn't block the sync
262 // pipeline. Log + continue.
263 if deps.Logger != nil {
264 deps.Logger.WarnContext(ctx, "pulls: position remap", "error", err, "pr_id", prID)
265 }
266 }
267 // Reset mergeability to unknown so the next mergeability tick
268 // recomputes against the fresh snapshot.
269 if err := q.SetPullRequestMergeability(ctx, deps.Pool, pullsdb.SetPullRequestMergeabilityParams{
270 IssueID: prID,
271 Mergeable: pgtype.Bool{},
272 MergeableState: pullsdb.PrMergeableStateUnknown,
273 }); err != nil {
274 return fmt.Errorf("reset mergeability: %w", err)
275 }
276 // Emit the synchronized timeline event.
277 iq := issuesdb.New()
278 if _, err := iq.InsertIssueEvent(ctx, deps.Pool, issuesdb.InsertIssueEventParams{
279 IssueID: prID,
280 Kind: "synchronized",
281 Meta: []byte(fmt.Sprintf(`{"head_oid":%q}`, headOID)),
282 }); err != nil {
283 return fmt.Errorf("emit event: %w", err)
284 }
285 return nil
286 }
287
288 // Mergeability runs the merge-tree probe and persists the result.
289 // Order of state checks (highest priority first):
290 //
291 // dirty — git merge-tree reports conflicts
292 // behind — head has no commits ahead of base
293 // blocked — required reviews missing OR an undismissed
294 // request_changes review exists (S23 gate)
295 // clean — merge-tree clean and review gate satisfied
296 //
297 // `blocked` is set by the S23 review evaluator; when no protection
298 // rule applies and no request_changes review exists, the gate is a
299 // no-op and we fall through to clean.
300 func Mergeability(ctx context.Context, deps Deps, gitDir string, prID int64) error {
301 q := pullsdb.New()
302 pr, err := q.GetPullRequestByIssueID(ctx, deps.Pool, prID)
303 if err != nil {
304 return err
305 }
306 if pr.BaseOid == "" || pr.HeadOid == "" {
307 return nil // synchronize hasn't run yet; nothing to probe
308 }
309 // Behind: head has no commits ahead of base.
310 commits, err := repogit.CommitsBetweenDetail(ctx, gitDir, pr.BaseOid, pr.HeadOid, 1)
311 if err != nil && !errors.Is(err, repogit.ErrRefNotFound) {
312 return err
313 }
314 if len(commits) == 0 {
315 return q.SetPullRequestMergeability(ctx, deps.Pool, pullsdb.SetPullRequestMergeabilityParams{
316 IssueID: prID,
317 Mergeable: pgtype.Bool{Bool: false, Valid: true},
318 MergeableState: pullsdb.PrMergeableStateBehind,
319 })
320 }
321 res, err := repogit.ProbeMerge(ctx, gitDir, pr.BaseOid, pr.HeadOid)
322 if err != nil {
323 return fmt.Errorf("probe: %w", err)
324 }
325 if res.HasConflict {
326 return q.SetPullRequestMergeability(ctx, deps.Pool, pullsdb.SetPullRequestMergeabilityParams{
327 IssueID: prID,
328 Mergeable: pgtype.Bool{Bool: false, Valid: true},
329 MergeableState: pullsdb.PrMergeableStateDirty,
330 })
331 }
332 // Composed gate: review (S23) + required-checks (S24). Either one
333 // failing produces `blocked`. The two evaluators are independent;
334 // each loads its own slice of the protection rule.
335 issue, err := issuesdb.New().GetIssueByID(ctx, deps.Pool, prID)
336 if err != nil {
337 return fmt.Errorf("load issue: %w", err)
338 }
339 reviewGate, err := review.Evaluate(ctx, deps.Pool, review.GateInputs{
340 RepoID: issue.RepoID,
341 BaseRef: pr.BaseRef,
342 PRIssueID: prID,
343 }, int64FromPg(issue.AuthorUserID))
344 if err != nil {
345 return fmt.Errorf("review gate: %w", err)
346 }
347 requiredCheckNames, err := loadRequiredCheckNames(ctx, deps.Pool, issue.RepoID, pr.BaseRef)
348 if err != nil {
349 return fmt.Errorf("required-check rule lookup: %w", err)
350 }
351 checksGate, err := checks.EvaluateRequiredChecks(ctx, deps.Pool, checks.GateInputs{
352 RepoID: issue.RepoID,
353 HeadSHA: pr.HeadOid,
354 RequiredNames: requiredCheckNames,
355 })
356 if err != nil {
357 return fmt.Errorf("checks gate: %w", err)
358 }
359 state := pullsdb.PrMergeableStateClean
360 mergeable := true
361 if !reviewGate.Satisfied || !checksGate.Satisfied {
362 state = pullsdb.PrMergeableStateBlocked
363 mergeable = false
364 }
365 return q.SetPullRequestMergeability(ctx, deps.Pool, pullsdb.SetPullRequestMergeabilityParams{
366 IssueID: prID,
367 Mergeable: pgtype.Bool{Bool: mergeable, Valid: true},
368 MergeableState: state,
369 })
370 }
371
372 // loadRequiredCheckNames returns the `status_checks_required` list
373 // from the longest-pattern-matching protection rule for `baseRef`.
374 // Empty slice means no rule, no required checks.
375 func loadRequiredCheckNames(ctx context.Context, pool *pgxpool.Pool, repoID int64, baseRef string) ([]string, error) {
376 rules, err := reposdb.New().ListBranchProtectionRules(ctx, pool, repoID)
377 if err != nil {
378 return nil, err
379 }
380 rule, ok := protection.MatchLongestRule(rules, baseRef)
381 if !ok {
382 return []string{}, nil
383 }
384 return rule.StatusChecksRequired, nil
385 }
386
387 // int64FromPg unwraps a pgtype.Int8; returns 0 when invalid.
388 func int64FromPg(p pgtype.Int8) int64 {
389 if !p.Valid {
390 return 0
391 }
392 return p.Int64
393 }
394
395 // EditPR updates the PR's title + body. Body markdown is re-rendered
396 // via the same pipeline issues.Create uses so HTML is consistent.
397 func EditPR(ctx context.Context, deps Deps, prID int64, title, body string) error {
398 title = strings.TrimSpace(title)
399 if title == "" {
400 return issues.ErrEmptyTitle
401 }
402 if len(title) > 256 {
403 return issues.ErrTitleTooLong
404 }
405 if len(body) > 65535 {
406 return issues.ErrBodyTooLong
407 }
408 html := renderBodyHTML(ctx, deps, body)
409 q := issuesdb.New()
410 return q.UpdateIssueTitleBody(ctx, deps.Pool, issuesdb.UpdateIssueTitleBodyParams{
411 ID: prID,
412 Title: title,
413 Body: body,
414 BodyHtmlCached: pgtype.Text{String: html, Valid: html != ""},
415 })
416 }
417
418 // SetReady flips draft → false and emits a `ready_for_review` event.
419 func SetReady(ctx context.Context, deps Deps, actorUserID, prID int64) error {
420 q := pullsdb.New()
421 tx, err := deps.Pool.Begin(ctx)
422 if err != nil {
423 return err
424 }
425 committed := false
426 defer func() {
427 if !committed {
428 _ = tx.Rollback(ctx)
429 }
430 }()
431 if err := q.SetPullRequestDraft(ctx, tx, pullsdb.SetPullRequestDraftParams{IssueID: prID, Draft: false}); err != nil {
432 return err
433 }
434 iq := issuesdb.New()
435 if _, err := iq.InsertIssueEvent(ctx, tx, issuesdb.InsertIssueEventParams{
436 IssueID: prID,
437 ActorUserID: pgtype.Int8{Int64: actorUserID, Valid: actorUserID != 0},
438 Kind: "ready_for_review",
439 Meta: []byte("{}"),
440 }); err != nil {
441 return err
442 }
443 if err := tx.Commit(ctx); err != nil {
444 return err
445 }
446 committed = true
447 return nil
448 }
449
450 // AllowedMethod returns true when the repo allows the named merge
451 // strategy. Falls open for unknown methods so callers get a clear
452 // error from the orchestrator.
453 func AllowedMethod(repo reposdb.Repo, method string) bool {
454 switch method {
455 case "merge":
456 return repo.AllowMergeCommit
457 case "squash":
458 return repo.AllowSquashMerge
459 case "rebase":
460 return repo.AllowRebaseMerge
461 }
462 return false
463 }
464
465 // renderBodyHTML wraps markdown.RenderHTML with a logger-aware error
466 // path. PR body length is bounded upstream at 65535 chars by the
467 // orchestrator; markdown caps at 1 MiB. ErrInputTooLarge here means
468 // a precondition regressed — log loudly. (S00-S25 audit, M.)
469 func renderBodyHTML(ctx context.Context, deps Deps, body string) string {
470 html, err := mdrender.RenderHTML([]byte(body))
471 if err != nil && deps.Logger != nil {
472 deps.Logger.WarnContext(ctx, "pulls: markdown render failed",
473 "error", err, "body_bytes", len(body))
474 }
475 return html
476 }
477