Go · 11985 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package pulls_test
4
5 import (
6 "context"
7 "io"
8 "log/slog"
9 "os"
10 "os/exec"
11 "path/filepath"
12 "strings"
13 "sync"
14 "testing"
15
16 "github.com/jackc/pgx/v5/pgtype"
17 "github.com/jackc/pgx/v5/pgxpool"
18
19 issuesdb "github.com/tenseleyFlow/shithub/internal/issues/sqlc"
20 "github.com/tenseleyFlow/shithub/internal/pulls"
21 pullsdb "github.com/tenseleyFlow/shithub/internal/pulls/sqlc"
22 reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
23 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
24 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
25 )
26
27 const fixtureHash = "$argon2id$v=19$m=16384,t=1,p=1$" +
28 "AAAAAAAAAAAAAAAA$" +
29 "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
30
31 // gitCmd suppresses the gosec G204 noise — every invocation runs
32 // against a t.TempDir path the test set up.
33 func gitCmd(args ...string) *exec.Cmd {
34 //nolint:gosec
35 return exec.Command("git", args...)
36 }
37
38 // fixture spins a real bare git repo on disk + a DB row pair (user
39 // + repo + ensured issue counter) so the orchestrator's path-on-disk
40 // assumptions hold.
41 type fixture struct {
42 pool *pgxpool.Pool
43 deps pulls.Deps
44 userID int64
45 repoID int64
46 gitDir string
47 }
48
49 func setup(t *testing.T) fixture {
50 t.Helper()
51 pool := dbtest.NewTestDB(t)
52 ctx := context.Background()
53
54 uq := usersdb.New()
55 user, err := uq.CreateUser(ctx, pool, usersdb.CreateUserParams{
56 Username: "alice", DisplayName: "Alice", PasswordHash: fixtureHash,
57 })
58 if err != nil {
59 t.Fatalf("CreateUser: %v", err)
60 }
61 em, err := uq.CreateUserEmail(ctx, pool, usersdb.CreateUserEmailParams{
62 UserID: user.ID, Email: "alice@example.com", IsPrimary: true, Verified: true,
63 })
64 if err != nil {
65 t.Fatalf("CreateUserEmail: %v", err)
66 }
67 if err := uq.LinkUserPrimaryEmail(ctx, pool, usersdb.LinkUserPrimaryEmailParams{
68 ID: user.ID, PrimaryEmailID: pgtype.Int8{Int64: em.ID, Valid: true},
69 }); err != nil {
70 t.Fatalf("LinkUserPrimaryEmail: %v", err)
71 }
72
73 rq := reposdb.New()
74 repo, err := rq.CreateRepo(ctx, pool, reposdb.CreateRepoParams{
75 OwnerUserID: pgtype.Int8{Int64: user.ID, Valid: true},
76 Name: "demo",
77 DefaultBranch: "trunk",
78 Visibility: reposdb.RepoVisibilityPublic,
79 })
80 if err != nil {
81 t.Fatalf("CreateRepo: %v", err)
82 }
83
84 iq := issuesdb.New()
85 if err := iq.EnsureRepoIssueCounter(ctx, pool, repo.ID); err != nil {
86 t.Fatalf("EnsureRepoIssueCounter: %v", err)
87 }
88
89 root := t.TempDir()
90 gitDir := filepath.Join(root, "demo.git")
91 if out, err := gitCmd("init", "--bare", "-b", "trunk", gitDir).CombinedOutput(); err != nil {
92 t.Fatalf("git init --bare: %v (%s)", err, out)
93 }
94
95 w := io.Discard
96 if testing.Verbose() {
97 w = os.Stderr
98 }
99 deps := pulls.Deps{
100 Pool: pool,
101 Logger: slog.New(slog.NewTextHandler(w, nil)),
102 }
103 return fixture{pool: pool, deps: deps, userID: user.ID, repoID: repo.ID, gitDir: gitDir}
104 }
105
106 // commitOnBranch creates a commit on branch from a temp worktree.
107 // Returns the new HEAD oid.
108 func commitOnBranch(t *testing.T, gitDir, branch, msg, file, contents string) string {
109 t.Helper()
110 wt := t.TempDir()
111 // Add a worktree that creates the branch if missing.
112 addArgs := []string{"-C", gitDir, "worktree", "add"}
113 // If branch doesn't exist yet, create it; otherwise check it out.
114 if _, err := gitCmd("-C", gitDir, "show-ref", "--verify", "refs/heads/"+branch).CombinedOutput(); err != nil {
115 addArgs = append(addArgs, "-b", branch, wt)
116 } else {
117 addArgs = append(addArgs, wt, branch)
118 }
119 if out, err := gitCmd(addArgs...).CombinedOutput(); err != nil {
120 t.Fatalf("worktree add %s: %v (%s)", branch, err, out)
121 }
122 defer func() {
123 _ = gitCmd("-C", gitDir, "worktree", "remove", "--force", wt).Run()
124 }()
125
126 if err := os.WriteFile(filepath.Join(wt, file), []byte(contents), 0o644); err != nil { //nolint:gosec
127 t.Fatalf("write %s: %v", file, err)
128 }
129 for _, args := range [][]string{
130 {"-C", wt, "config", "user.name", "Alice"},
131 {"-C", wt, "config", "user.email", "alice@example.com"},
132 {"-C", wt, "add", "."},
133 {"-C", wt, "commit", "-m", msg},
134 } {
135 if out, err := gitCmd(args...).CombinedOutput(); err != nil {
136 t.Fatalf("%v: %v (%s)", args, err, out)
137 }
138 }
139 out, err := gitCmd("-C", wt, "rev-parse", "HEAD").Output()
140 if err != nil {
141 t.Fatalf("rev-parse HEAD: %v", err)
142 }
143 return strings.TrimSpace(string(out))
144 }
145
146 func TestCreate_OpensPRWithIssueRow(t *testing.T) {
147 f := setup(t)
148 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
149 commitOnBranch(t, f.gitDir, "feature", "add foo", "foo.txt", "foo\n")
150
151 res, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
152 RepoID: f.repoID,
153 AuthorUserID: f.userID,
154 Title: "Add foo",
155 Body: "fixes nothing yet",
156 BaseRef: "trunk",
157 HeadRef: "feature",
158 GitDir: f.gitDir,
159 })
160 if err != nil {
161 t.Fatalf("Create: %v", err)
162 }
163 if res.Issue.Kind != issuesdb.IssueKindPr {
164 t.Errorf("issue kind: got %s, want pr", res.Issue.Kind)
165 }
166 if res.PullRequest.BaseRef != "trunk" || res.PullRequest.HeadRef != "feature" {
167 t.Errorf("ref mismatch: %+v", res.PullRequest)
168 }
169 if res.PullRequest.BaseOid == "" || res.PullRequest.HeadOid == "" {
170 t.Errorf("OIDs not snapshotted: %+v", res.PullRequest)
171 }
172 commits, _ := pullsdb.New().ListPullRequestCommits(context.Background(), f.pool, res.PullRequest.IssueID)
173 if len(commits) == 0 {
174 t.Errorf("expected commits populated by initial sync")
175 }
176 }
177
178 func TestCreate_RejectsSameBranch(t *testing.T) {
179 f := setup(t)
180 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
181 _, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
182 RepoID: f.repoID, AuthorUserID: f.userID,
183 Title: "x", BaseRef: "trunk", HeadRef: "trunk", GitDir: f.gitDir,
184 })
185 if err == nil {
186 t.Fatalf("expected ErrSameBranch, got nil")
187 }
188 }
189
190 func TestMergeability_Clean(t *testing.T) {
191 f := setup(t)
192 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
193 commitOnBranch(t, f.gitDir, "feature", "add foo", "foo.txt", "foo\n")
194 res, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
195 RepoID: f.repoID, AuthorUserID: f.userID,
196 Title: "x", BaseRef: "trunk", HeadRef: "feature", GitDir: f.gitDir,
197 })
198 if err != nil {
199 t.Fatalf("Create: %v", err)
200 }
201 if err := pulls.Mergeability(context.Background(), f.deps, f.gitDir, res.PullRequest.IssueID); err != nil {
202 t.Fatalf("Mergeability: %v", err)
203 }
204 pr, _ := pullsdb.New().GetPullRequestByIssueID(context.Background(), f.pool, res.PullRequest.IssueID)
205 if pr.MergeableState != pullsdb.PrMergeableStateClean {
206 t.Errorf("got %s, want clean", pr.MergeableState)
207 }
208 }
209
210 func TestMergeability_Dirty(t *testing.T) {
211 f := setup(t)
212 commitOnBranch(t, f.gitDir, "trunk", "init", "shared.txt", "base content\n")
213 // Modify shared.txt on trunk.
214 commitOnBranch(t, f.gitDir, "trunk", "trunk edit", "shared.txt", "trunk content\n")
215 // Branch from earlier trunk and also edit shared.txt → conflict.
216 // Create the feature branch from the first trunk commit.
217 out, err := gitCmd("-C", f.gitDir, "rev-list", "--reverse", "trunk").Output()
218 if err != nil {
219 t.Fatalf("rev-list: %v", err)
220 }
221 firstSHA := strings.SplitN(strings.TrimSpace(string(out)), "\n", 2)[0]
222 if out, err := gitCmd("-C", f.gitDir, "branch", "feature", firstSHA).CombinedOutput(); err != nil {
223 t.Fatalf("create feature branch: %v (%s)", err, out)
224 }
225 commitOnBranch(t, f.gitDir, "feature", "feature edit", "shared.txt", "feature content\n")
226
227 res, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
228 RepoID: f.repoID, AuthorUserID: f.userID,
229 Title: "x", BaseRef: "trunk", HeadRef: "feature", GitDir: f.gitDir,
230 })
231 if err != nil {
232 t.Fatalf("Create: %v", err)
233 }
234 if err := pulls.Mergeability(context.Background(), f.deps, f.gitDir, res.PullRequest.IssueID); err != nil {
235 t.Fatalf("Mergeability: %v", err)
236 }
237 pr, _ := pullsdb.New().GetPullRequestByIssueID(context.Background(), f.pool, res.PullRequest.IssueID)
238 if pr.MergeableState != pullsdb.PrMergeableStateDirty {
239 t.Errorf("got %s, want dirty", pr.MergeableState)
240 }
241 }
242
243 func TestMerge_MergeCommit(t *testing.T) {
244 f := setup(t)
245 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
246 commitOnBranch(t, f.gitDir, "feature", "add foo", "foo.txt", "foo\n")
247 res, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
248 RepoID: f.repoID, AuthorUserID: f.userID,
249 Title: "Add foo", BaseRef: "trunk", HeadRef: "feature", GitDir: f.gitDir,
250 })
251 if err != nil {
252 t.Fatalf("Create: %v", err)
253 }
254 if err := pulls.Mergeability(context.Background(), f.deps, f.gitDir, res.PullRequest.IssueID); err != nil {
255 t.Fatalf("Mergeability: %v", err)
256 }
257 if err := pulls.Merge(context.Background(), f.deps, pulls.MergeParams{
258 PRID: res.PullRequest.IssueID, ActorUserID: f.userID,
259 GitDir: f.gitDir, Method: "merge",
260 }); err != nil {
261 t.Fatalf("Merge: %v", err)
262 }
263 pr, _ := pullsdb.New().GetPullRequestByIssueID(context.Background(), f.pool, res.PullRequest.IssueID)
264 if !pr.MergedAt.Valid {
265 t.Errorf("merged_at not set")
266 }
267 if !pr.MergeCommitSha.Valid {
268 t.Errorf("merge_commit_sha not set")
269 }
270 // Issue side closed?
271 iq := issuesdb.New()
272 issue, _ := iq.GetIssueByID(context.Background(), f.pool, res.PullRequest.IssueID)
273 if issue.State != issuesdb.IssueStateClosed {
274 t.Errorf("issue state: got %s, want closed", issue.State)
275 }
276 }
277
278 func TestMerge_RejectsConcurrentDouble(t *testing.T) {
279 f := setup(t)
280 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
281 commitOnBranch(t, f.gitDir, "feature", "add foo", "foo.txt", "foo\n")
282 res, err := pulls.Create(context.Background(), f.deps, pulls.CreateParams{
283 RepoID: f.repoID, AuthorUserID: f.userID,
284 Title: "x", BaseRef: "trunk", HeadRef: "feature", GitDir: f.gitDir,
285 })
286 if err != nil {
287 t.Fatalf("Create: %v", err)
288 }
289 _ = pulls.Mergeability(context.Background(), f.deps, f.gitDir, res.PullRequest.IssueID)
290
291 var wg sync.WaitGroup
292 errs := make([]error, 2)
293 for i := 0; i < 2; i++ {
294 wg.Add(1)
295 go func(i int) {
296 defer wg.Done()
297 errs[i] = pulls.Merge(context.Background(), f.deps, pulls.MergeParams{
298 PRID: res.PullRequest.IssueID, ActorUserID: f.userID,
299 GitDir: f.gitDir, Method: "merge",
300 })
301 }(i)
302 }
303 wg.Wait()
304
305 successes := 0
306 for _, e := range errs {
307 if e == nil {
308 successes++
309 }
310 }
311 if successes != 1 {
312 t.Errorf("expected exactly one successful merge, got %d (errors: %v, %v)", successes, errs[0], errs[1])
313 }
314 }
315
316 func TestMerge_LinkedIssueAutoClose(t *testing.T) {
317 f := setup(t)
318 ctx := context.Background()
319
320 // Create an issue first so the PR body can reference it.
321 iq := issuesdb.New()
322 num, err := iq.AllocateIssueNumber(ctx, f.pool, f.repoID)
323 if err != nil {
324 t.Fatalf("AllocateIssueNumber: %v", err)
325 }
326 issue, err := iq.CreateIssue(ctx, f.pool, issuesdb.CreateIssueParams{
327 RepoID: f.repoID,
328 Number: num,
329 Kind: issuesdb.IssueKindIssue,
330 Title: "bug",
331 Body: "fix me",
332 AuthorUserID: pgtype.Int8{Int64: f.userID, Valid: true},
333 })
334 if err != nil {
335 t.Fatalf("CreateIssue: %v", err)
336 }
337
338 commitOnBranch(t, f.gitDir, "trunk", "init", "README.md", "hi\n")
339 commitOnBranch(t, f.gitDir, "feature", "add foo", "foo.txt", "foo\n")
340
341 res, err := pulls.Create(ctx, f.deps, pulls.CreateParams{
342 RepoID: f.repoID, AuthorUserID: f.userID,
343 Title: "fix the bug", Body: "Fixes #1",
344 BaseRef: "trunk", HeadRef: "feature", GitDir: f.gitDir,
345 })
346 if err != nil {
347 t.Fatalf("Create PR: %v", err)
348 }
349 _ = pulls.Mergeability(ctx, f.deps, f.gitDir, res.PullRequest.IssueID)
350
351 if err := pulls.Merge(ctx, f.deps, pulls.MergeParams{
352 PRID: res.PullRequest.IssueID, ActorUserID: f.userID,
353 GitDir: f.gitDir, Method: "squash",
354 }); err != nil {
355 t.Fatalf("Merge: %v", err)
356 }
357
358 // The pre-existing issue (#1) should now be closed.
359 got, err := iq.GetIssueByID(ctx, f.pool, issue.ID)
360 if err != nil {
361 t.Fatalf("GetIssueByID: %v", err)
362 }
363 if got.State != issuesdb.IssueStateClosed {
364 t.Errorf("linked issue state: got %s, want closed", got.State)
365 }
366 }
367