Go · 13493 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package api_test
4
5 import (
6 "bytes"
7 "context"
8 "encoding/json"
9 "net/http"
10 "net/http/httptest"
11 "os"
12 "os/exec"
13 "path/filepath"
14 "strings"
15 "testing"
16
17 "github.com/jackc/pgx/v5/pgxpool"
18
19 "github.com/tenseleyFlow/shithub/internal/auth/audit"
20 "github.com/tenseleyFlow/shithub/internal/auth/pat"
21 "github.com/tenseleyFlow/shithub/internal/auth/throttle"
22 "github.com/tenseleyFlow/shithub/internal/repos"
23 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
24 )
25
// apiPull holds the pull-request JSON fields these tests decode from
// /api/v1/repos/{owner}/{repo}/pulls responses. Keys a response carries
// that are not declared here are ignored by encoding/json.
type apiPull struct {
	ID     int64  `json:"id"`
	Number int64  `json:"number"` // the number used in /pulls/{number} URLs
	Title  string `json:"title"`
	Body   string `json:"body"`
	State  string `json:"state"` // compared against "open" by these tests
	Draft  bool   `json:"draft"`
	// Branch names the PR was opened with (e.g. "trunk" / "feature").
	BaseRef string `json:"base_ref"`
	HeadRef string `json:"head_ref"`
	// NOTE(review): presumably commit SHAs of base/head — not asserted here.
	BaseOID        string `json:"base_oid"`
	HeadOID        string `json:"head_oid"`
	MergeableState string `json:"mergeable_state"`
	Merged         bool   `json:"merged"`
	MergeCommit    string `json:"merge_commit_sha"` // note: JSON key differs from the field name
	MergeMethod    string `json:"merge_method"`
	// Kept as a raw string; presumably a timestamp or empty when unmerged — TODO confirm.
	MergedAt string `json:"merged_at"`
	AuthorID int64  `json:"author_id"`
}
44
// gitCmdAPI returns an unstarted `git` command with args appended after
// the program name. Callers run it themselves (Run, Output, ...). The
// arguments come from test code pointing at directories the test itself
// set up, which is why the gosec warning is suppressed.
//
//nolint:gosec
func gitCmdAPI(args ...string) *exec.Cmd {
	cmd := exec.Command("git", args...)
	return cmd
}
50
51 // commitOnRepoBranch lands a commit on `branch` of the bare repo at
52 // gitDir via a temp worktree. Branch is created if missing.
53 func commitOnRepoBranch(t *testing.T, gitDir, branch, msg, file, contents string) string {
54 t.Helper()
55 wt := t.TempDir()
56 addArgs := []string{"-C", gitDir, "worktree", "add"}
57 if _, err := gitCmdAPI("-C", gitDir, "show-ref", "--verify", "refs/heads/"+branch).CombinedOutput(); err != nil {
58 addArgs = append(addArgs, "-b", branch, wt)
59 } else {
60 addArgs = append(addArgs, wt, branch)
61 }
62 if out, err := gitCmdAPI(addArgs...).CombinedOutput(); err != nil {
63 t.Fatalf("worktree add %s: %v (%s)", branch, err, out)
64 }
65 defer func() {
66 _ = gitCmdAPI("-C", gitDir, "worktree", "remove", "--force", wt).Run()
67 }()
68
69 if err := os.WriteFile(filepath.Join(wt, file), []byte(contents), 0o644); err != nil { //nolint:gosec
70 t.Fatalf("write %s: %v", file, err)
71 }
72 for _, args := range [][]string{
73 {"-C", wt, "config", "user.name", "Alice"},
74 {"-C", wt, "config", "user.email", "alice@example.com"},
75 {"-C", wt, "add", "."},
76 {"-C", wt, "commit", "-m", msg},
77 } {
78 if out, err := gitCmdAPI(args...).CombinedOutput(); err != nil {
79 t.Fatalf("%v: %v (%s)", args, err, out)
80 }
81 }
82 out, err := gitCmdAPI("-C", wt, "rev-parse", "HEAD").Output()
83 if err != nil {
84 t.Fatalf("rev-parse HEAD: %v", err)
85 }
86 return strings.TrimSpace(string(out))
87 }
88
89 // seedPullsEnv builds the full PR-test environment: pool, router, owner
90 // user, an `alice/demo` repo initialized through repos.Create, a
91 // `trunk` commit, a `feature` commit, and a PAT scoped to repo:write.
92 // gitDir is returned so individual tests can add more commits when
93 // needed (e.g. to dirty the mergeable_state).
94 func seedPullsEnv(t *testing.T, ownerUsername string) (pool *pgxpool.Pool, router http.Handler, userID, repoID int64, token, gitDir string) {
95 t.Helper()
96 pool = dbtest.NewTestDB(t)
97 router, rfs := newReposAPIRouter(t, pool)
98 userID = seedRepoCreatorUser(t, pool, ownerUsername)
99 token = mintRunnerAPIPAT(t, pool, userID, string(pat.ScopeRepoWrite))
100
101 res, err := repos.Create(context.Background(), repos.Deps{
102 Pool: pool,
103 RepoFS: rfs,
104 Audit: audit.NewRecorder(),
105 Limiter: throttle.NewLimiter(),
106 }, repos.Params{
107 ActorUserID: userID,
108 OwnerUserID: userID,
109 OwnerUsername: ownerUsername,
110 Name: "demo",
111 Description: "demo",
112 Visibility: "public",
113 })
114 if err != nil {
115 t.Fatalf("repos.Create: %v", err)
116 }
117 repoID = res.Repo.ID
118 gitDir, err = rfs.RepoPath(ownerUsername, "demo")
119 if err != nil {
120 t.Fatalf("RepoFS.RepoPath: %v", err)
121 }
122 commitOnRepoBranch(t, gitDir, "trunk", "init", "README.md", "hi\n")
123 commitOnRepoBranch(t, gitDir, "feature", "add foo", "foo.txt", "foo\n")
124 return pool, router, userID, repoID, token, gitDir
125 }
126
127 func TestPulls_CreateAndGet(t *testing.T) {
128 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
129
130 body, _ := json.Marshal(map[string]any{
131 "title": "wire up foo", "body": "first cut",
132 "base": "trunk", "head": "feature",
133 })
134 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/pulls", bytes.NewReader(body))
135 req.Header.Set("Authorization", "Bearer "+token)
136 rr := httptest.NewRecorder()
137 router.ServeHTTP(rr, req)
138 if rr.Code != http.StatusCreated {
139 t.Fatalf("create: %d; body=%s", rr.Code, rr.Body.String())
140 }
141 var created apiPull
142 if err := json.Unmarshal(rr.Body.Bytes(), &created); err != nil {
143 t.Fatalf("decode: %v", err)
144 }
145 if created.Title != "wire up foo" || created.BaseRef != "trunk" || created.HeadRef != "feature" {
146 t.Errorf("shape: %+v", created)
147 }
148 if created.State != "open" || created.Merged {
149 t.Errorf("state: %+v", created)
150 }
151
152 req = httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls/1", nil)
153 req.Header.Set("Authorization", "Bearer "+token)
154 rr = httptest.NewRecorder()
155 router.ServeHTTP(rr, req)
156 if rr.Code != http.StatusOK {
157 t.Fatalf("get: %d; body=%s", rr.Code, rr.Body.String())
158 }
159 }
160
161 func TestPulls_CreateRejectsSameBranch(t *testing.T) {
162 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
163
164 body, _ := json.Marshal(map[string]any{"title": "self", "base": "trunk", "head": "trunk"})
165 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/pulls", bytes.NewReader(body))
166 req.Header.Set("Authorization", "Bearer "+token)
167 rr := httptest.NewRecorder()
168 router.ServeHTTP(rr, req)
169 if rr.Code != http.StatusUnprocessableEntity {
170 t.Fatalf("status: got %d, want 422; body=%s", rr.Code, rr.Body.String())
171 }
172 }
173
174 func TestPulls_CreateRejectsMissingHead(t *testing.T) {
175 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
176
177 body, _ := json.Marshal(map[string]any{"title": "ghost", "base": "trunk", "head": "no-such-branch"})
178 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/pulls", bytes.NewReader(body))
179 req.Header.Set("Authorization", "Bearer "+token)
180 rr := httptest.NewRecorder()
181 router.ServeHTTP(rr, req)
182 if rr.Code != http.StatusUnprocessableEntity {
183 t.Fatalf("status: got %d, want 422; body=%s", rr.Code, rr.Body.String())
184 }
185 }
186
187 func TestPulls_CreateRequiresRepoWriteScope(t *testing.T) {
188 pool, router, userID, _, _, _ := seedPullsEnv(t, "alice")
189 readOnly := mintRunnerAPIPAT(t, pool, userID, string(pat.ScopeRepoRead))
190
191 body, _ := json.Marshal(map[string]any{"title": "x", "base": "trunk", "head": "feature"})
192 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/pulls", bytes.NewReader(body))
193 req.Header.Set("Authorization", "Bearer "+readOnly)
194 rr := httptest.NewRecorder()
195 router.ServeHTTP(rr, req)
196 if rr.Code != http.StatusForbidden {
197 t.Fatalf("status: got %d, want 403; body=%s", rr.Code, rr.Body.String())
198 }
199 }
200
201 func TestPulls_PatchTitleBody(t *testing.T) {
202 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
203
204 openPullFor(t, router, token, "alice", "demo")
205
206 patch, _ := json.Marshal(map[string]any{"title": "renamed", "body": "renamed body"})
207 req := httptest.NewRequest(http.MethodPatch, "/api/v1/repos/alice/demo/pulls/1", bytes.NewReader(patch))
208 req.Header.Set("Authorization", "Bearer "+token)
209 rr := httptest.NewRecorder()
210 router.ServeHTTP(rr, req)
211 if rr.Code != http.StatusOK {
212 t.Fatalf("patch: %d; body=%s", rr.Code, rr.Body.String())
213 }
214 var updated apiPull
215 if err := json.Unmarshal(rr.Body.Bytes(), &updated); err != nil {
216 t.Fatalf("decode: %v", err)
217 }
218 if updated.Title != "renamed" || updated.Body != "renamed body" {
219 t.Errorf("patch shape: %+v", updated)
220 }
221 }
222
223 func TestPulls_PatchNonAuthorForbidden(t *testing.T) {
224 pool, router, _, _, tokenAlice, _ := seedPullsEnv(t, "alice")
225 openPullFor(t, router, tokenAlice, "alice", "demo")
226
227 bobID := seedRepoCreatorUser(t, pool, "bob")
228 tokenBob := mintRunnerAPIPAT(t, pool, bobID, string(pat.ScopeRepoWrite))
229
230 patch, _ := json.Marshal(map[string]any{"title": "hijack"})
231 req := httptest.NewRequest(http.MethodPatch, "/api/v1/repos/alice/demo/pulls/1", bytes.NewReader(patch))
232 req.Header.Set("Authorization", "Bearer "+tokenBob)
233 rr := httptest.NewRecorder()
234 router.ServeHTTP(rr, req)
235 if rr.Code != http.StatusForbidden {
236 t.Fatalf("status: got %d, want 403; body=%s", rr.Code, rr.Body.String())
237 }
238 }
239
240 func TestPulls_PatchDraftToReady(t *testing.T) {
241 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
242 body, _ := json.Marshal(map[string]any{
243 "title": "draft pr", "base": "trunk", "head": "feature", "draft": true,
244 })
245 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/pulls", bytes.NewReader(body))
246 req.Header.Set("Authorization", "Bearer "+token)
247 rr := httptest.NewRecorder()
248 router.ServeHTTP(rr, req)
249 if rr.Code != http.StatusCreated {
250 t.Fatalf("draft create: %d; body=%s", rr.Code, rr.Body.String())
251 }
252
253 patch, _ := json.Marshal(map[string]any{"draft": false})
254 req = httptest.NewRequest(http.MethodPatch, "/api/v1/repos/alice/demo/pulls/1", bytes.NewReader(patch))
255 req.Header.Set("Authorization", "Bearer "+token)
256 rr = httptest.NewRecorder()
257 router.ServeHTTP(rr, req)
258 if rr.Code != http.StatusOK {
259 t.Fatalf("flip draft: %d; body=%s", rr.Code, rr.Body.String())
260 }
261 var updated apiPull
262 _ = json.Unmarshal(rr.Body.Bytes(), &updated)
263 if updated.Draft {
264 t.Errorf("expected draft=false; got %+v", updated)
265 }
266 }
267
268 func TestPulls_ListFiltersByState(t *testing.T) {
269 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
270 openPullFor(t, router, token, "alice", "demo")
271
272 req := httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls?state=open", nil)
273 req.Header.Set("Authorization", "Bearer "+token)
274 rr := httptest.NewRecorder()
275 router.ServeHTTP(rr, req)
276 if rr.Code != http.StatusOK {
277 t.Fatalf("list: %d", rr.Code)
278 }
279 var listed []apiPull
280 if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil {
281 t.Fatalf("decode: %v", err)
282 }
283 if len(listed) != 1 {
284 t.Errorf("open count: got %d, want 1", len(listed))
285 }
286
287 req = httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls?state=closed", nil)
288 req.Header.Set("Authorization", "Bearer "+token)
289 rr = httptest.NewRecorder()
290 router.ServeHTTP(rr, req)
291 if rr.Code != http.StatusOK {
292 t.Fatalf("closed list: %d", rr.Code)
293 }
294 if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil {
295 t.Fatalf("decode closed: %v", err)
296 }
297 if len(listed) != 0 {
298 t.Errorf("closed count: got %d, want 0", len(listed))
299 }
300 }
301
302 func TestPulls_CommitsAndFilesListed(t *testing.T) {
303 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
304 openPullFor(t, router, token, "alice", "demo")
305
306 req := httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls/1/commits", nil)
307 req.Header.Set("Authorization", "Bearer "+token)
308 rr := httptest.NewRecorder()
309 router.ServeHTTP(rr, req)
310 if rr.Code != http.StatusOK {
311 t.Fatalf("commits: %d; body=%s", rr.Code, rr.Body.String())
312 }
313 var commits []map[string]any
314 if err := json.Unmarshal(rr.Body.Bytes(), &commits); err != nil {
315 t.Fatalf("decode commits: %v", err)
316 }
317 if len(commits) == 0 {
318 t.Errorf("expected at least one commit on feature; got %+v", commits)
319 }
320
321 req = httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls/1/files", nil)
322 req.Header.Set("Authorization", "Bearer "+token)
323 rr = httptest.NewRecorder()
324 router.ServeHTTP(rr, req)
325 if rr.Code != http.StatusOK {
326 t.Fatalf("files: %d; body=%s", rr.Code, rr.Body.String())
327 }
328 var files []map[string]any
329 if err := json.Unmarshal(rr.Body.Bytes(), &files); err != nil {
330 t.Fatalf("decode files: %v", err)
331 }
332 if len(files) == 0 {
333 t.Errorf("expected at least one changed file; got %+v", files)
334 }
335 }
336
337 func TestPulls_GetReturns404ForNonPRNumber(t *testing.T) {
338 _, router, _, _, token, _ := seedPullsEnv(t, "alice")
339 body, _ := json.Marshal(map[string]any{"title": "plain issue"})
340 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/alice/demo/issues", bytes.NewReader(body))
341 req.Header.Set("Authorization", "Bearer "+token)
342 rr := httptest.NewRecorder()
343 router.ServeHTTP(rr, req)
344 if rr.Code != http.StatusCreated {
345 t.Fatalf("seed issue: %d; body=%s", rr.Code, rr.Body.String())
346 }
347
348 // Issue #1 is an issue, not a pull request — /pulls/1 must 404.
349 req = httptest.NewRequest(http.MethodGet, "/api/v1/repos/alice/demo/pulls/1", nil)
350 req.Header.Set("Authorization", "Bearer "+token)
351 rr = httptest.NewRecorder()
352 router.ServeHTTP(rr, req)
353 if rr.Code != http.StatusNotFound {
354 t.Fatalf("status: got %d, want 404; body=%s", rr.Code, rr.Body.String())
355 }
356 }
357
358 // openPullFor creates a default `trunk` <- `feature` PR on the supplied
359 // repo. Fails the test on non-201 so callers can keep their bodies
360 // short.
361 func openPullFor(t *testing.T, router http.Handler, token, owner, repo string) apiPull {
362 t.Helper()
363 body, _ := json.Marshal(map[string]any{"title": "default", "base": "trunk", "head": "feature"})
364 req := httptest.NewRequest(http.MethodPost, "/api/v1/repos/"+owner+"/"+repo+"/pulls", bytes.NewReader(body))
365 req.Header.Set("Authorization", "Bearer "+token)
366 rr := httptest.NewRecorder()
367 router.ServeHTTP(rr, req)
368 if rr.Code != http.StatusCreated {
369 t.Fatalf("openPullFor: %d; body=%s", rr.Code, rr.Body.String())
370 }
371 var out apiPull
372 _ = json.Unmarshal(rr.Body.Bytes(), &out)
373 return out
374 }
375