| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package api_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "encoding/json" |
| 8 | "fmt" |
| 9 | "net/http" |
| 10 | "net/http/httptest" |
| 11 | "strings" |
| 12 | "testing" |
| 13 | |
| 14 | "github.com/jackc/pgx/v5/pgtype" |
| 15 | "github.com/jackc/pgx/v5/pgxpool" |
| 16 | |
| 17 | "github.com/tenseleyFlow/shithub/internal/auth/audit" |
| 18 | "github.com/tenseleyFlow/shithub/internal/auth/pat" |
| 19 | "github.com/tenseleyFlow/shithub/internal/auth/throttle" |
| 20 | "github.com/tenseleyFlow/shithub/internal/infra/storage" |
| 21 | "github.com/tenseleyFlow/shithub/internal/repos" |
| 22 | reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc" |
| 23 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 24 | ) |
| 25 | |
// JSON decoding shapes for the rulesets REST surface. They mirror
// handlers/api.rulesetResponse for the fields these tests assert on;
// any extra JSON fields are ignored by encoding/json's default decoding.
type (
	// apiRuleset is a single ruleset object as returned by the list,
	// get-by-id, and rules-for-branch endpoints.
	apiRuleset struct {
		ID          int64                `json:"id"`
		Name        string               `json:"name"`
		Target      string               `json:"target"`
		SourceType  string               `json:"source_type"`
		Source      string               `json:"source"`
		Enforcement string               `json:"enforcement"`
		Conditions  apiRulesetConditions `json:"conditions"`
		Rules       []apiRulesetRule     `json:"rules"`
	}

	// apiRulesetConditions wraps the ref_name include/exclude lists.
	apiRulesetConditions struct {
		RefName apiRulesetRefName `json:"ref_name"`
	}

	// apiRulesetRefName holds ref patterns such as "refs/heads/trunk".
	apiRulesetRefName struct {
		Include []string `json:"include"`
		Exclude []string `json:"exclude"`
	}

	// apiRulesetRule is one rule entry. Parameters is decoded into a
	// generic map, so JSON numbers arrive as float64.
	apiRulesetRule struct {
		Type       string         `json:"type"`
		Parameters map[string]any `json:"parameters,omitempty"`
	}
)
| 51 | |
| 52 | type rulesetsEnv struct { |
| 53 | pool *pgxpool.Pool |
| 54 | router http.Handler |
| 55 | rfs *storage.RepoFS |
| 56 | token string |
| 57 | owner string |
| 58 | repo string |
| 59 | } |
| 60 | |
| 61 | func newRulesetsEnv(t *testing.T, ownerUsername string) rulesetsEnv { |
| 62 | t.Helper() |
| 63 | pool, router, rfs, token, owner, repoName := seedBranchesEnv(t, ownerUsername) |
| 64 | return rulesetsEnv{pool: pool, router: router, rfs: rfs, token: token, owner: owner, repo: repoName} |
| 65 | } |
| 66 | |
| 67 | func (e rulesetsEnv) get(t *testing.T, path string) *httptest.ResponseRecorder { |
| 68 | t.Helper() |
| 69 | req := httptest.NewRequest(http.MethodGet, path, nil) |
| 70 | req.Header.Set("Authorization", "Bearer "+e.token) |
| 71 | rr := httptest.NewRecorder() |
| 72 | e.router.ServeHTTP(rr, req) |
| 73 | return rr |
| 74 | } |
| 75 | |
| 76 | // repoID looks the repo row up via the existing repos REST endpoint; |
| 77 | // keeps the test independent of internal package wiring. |
| 78 | func (e rulesetsEnv) repoID(t *testing.T) int64 { |
| 79 | t.Helper() |
| 80 | rr := e.get(t, fmt.Sprintf("/api/v1/repos/%s/%s", e.owner, e.repo)) |
| 81 | if rr.Code != http.StatusOK { |
| 82 | t.Fatalf("repo lookup status %d; body=%s", rr.Code, rr.Body.String()) |
| 83 | } |
| 84 | var got struct { |
| 85 | ID int64 `json:"id"` |
| 86 | } |
| 87 | if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { |
| 88 | t.Fatalf("decode repo id: %v", err) |
| 89 | } |
| 90 | return got.ID |
| 91 | } |
| 92 | |
| 93 | // seedRule upserts a protection rule and (optionally) layers review |
| 94 | // settings on top. Returns the rule ID. |
| 95 | func seedRule(t *testing.T, pool *pgxpool.Pool, p reposdb.UpsertBranchProtectionRuleParams, reviewCount int32, requireCodeOwner bool) int64 { |
| 96 | t.Helper() |
| 97 | rq := reposdb.New() |
| 98 | id, err := rq.UpsertBranchProtectionRule(context.Background(), pool, p) |
| 99 | if err != nil { |
| 100 | t.Fatalf("UpsertBranchProtectionRule: %v", err) |
| 101 | } |
| 102 | if reviewCount > 0 || requireCodeOwner { |
| 103 | if err := rq.UpdateBranchProtectionReviewSettings(context.Background(), pool, reposdb.UpdateBranchProtectionReviewSettingsParams{ |
| 104 | ID: id, |
| 105 | RequiredReviewCount: reviewCount, |
| 106 | RequireCodeOwnerReview: requireCodeOwner, |
| 107 | }); err != nil { |
| 108 | t.Fatalf("UpdateBranchProtectionReviewSettings: %v", err) |
| 109 | } |
| 110 | } |
| 111 | return id |
| 112 | } |
| 113 | |
| 114 | func TestRulesets_ListEmpty(t *testing.T) { |
| 115 | env := newRulesetsEnv(t, "alice") |
| 116 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo)) |
| 117 | if rr.Code != http.StatusOK { |
| 118 | t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 119 | } |
| 120 | var listed []apiRuleset |
| 121 | if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 122 | t.Fatalf("decode: %v", err) |
| 123 | } |
| 124 | if len(listed) != 0 { |
| 125 | t.Errorf("expected empty list, got %d", len(listed)) |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | func TestRulesets_ListProjectsRule(t *testing.T) { |
| 130 | env := newRulesetsEnv(t, "alice") |
| 131 | id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 132 | RepoID: env.repoID(t), |
| 133 | Pattern: "trunk", |
| 134 | PreventForcePush: true, |
| 135 | PreventDeletion: true, |
| 136 | RequirePrForPush: false, |
| 137 | AllowedPusherUserIds: []int64{}, |
| 138 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 139 | }, 2, true) |
| 140 | |
| 141 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo)) |
| 142 | if rr.Code != http.StatusOK { |
| 143 | t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 144 | } |
| 145 | var listed []apiRuleset |
| 146 | if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 147 | t.Fatalf("decode: %v", err) |
| 148 | } |
| 149 | if len(listed) != 1 { |
| 150 | t.Fatalf("len: got %d, want 1; payload=%+v", len(listed), listed) |
| 151 | } |
| 152 | rs := listed[0] |
| 153 | if rs.ID != id { |
| 154 | t.Errorf("id: got %d, want %d", rs.ID, id) |
| 155 | } |
| 156 | if rs.Name != "Pattern: trunk" { |
| 157 | t.Errorf("name: %q", rs.Name) |
| 158 | } |
| 159 | if rs.Target != "branch" || rs.SourceType != "Repository" || rs.Enforcement != "active" { |
| 160 | t.Errorf("envelope fields: %+v", rs) |
| 161 | } |
| 162 | wantSource := strings.ToLower(env.owner) + "/" + env.repo |
| 163 | if rs.Source != wantSource { |
| 164 | t.Errorf("source: got %q, want %q", rs.Source, wantSource) |
| 165 | } |
| 166 | if len(rs.Conditions.RefName.Include) != 1 || rs.Conditions.RefName.Include[0] != "refs/heads/trunk" { |
| 167 | t.Errorf("conditions.ref_name.include: %+v", rs.Conditions.RefName.Include) |
| 168 | } |
| 169 | |
| 170 | have := map[string]apiRulesetRule{} |
| 171 | for _, r := range rs.Rules { |
| 172 | have[r.Type] = r |
| 173 | } |
| 174 | if _, ok := have["non_fast_forward"]; !ok { |
| 175 | t.Errorf("missing non_fast_forward rule; rules=%+v", rs.Rules) |
| 176 | } |
| 177 | if _, ok := have["deletion"]; !ok { |
| 178 | t.Errorf("missing deletion rule; rules=%+v", rs.Rules) |
| 179 | } |
| 180 | pr, ok := have["pull_request"] |
| 181 | if !ok { |
| 182 | t.Fatalf("missing pull_request rule; rules=%+v", rs.Rules) |
| 183 | } |
| 184 | if v, _ := pr.Parameters["required_approving_review_count"].(float64); int(v) != 2 { |
| 185 | t.Errorf("pull_request.required_approving_review_count: %v", pr.Parameters["required_approving_review_count"]) |
| 186 | } |
| 187 | if v, _ := pr.Parameters["require_code_owner_review"].(bool); !v { |
| 188 | t.Errorf("pull_request.require_code_owner_review: %v", pr.Parameters["require_code_owner_review"]) |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | func TestRulesets_GetSingle(t *testing.T) { |
| 193 | env := newRulesetsEnv(t, "alice") |
| 194 | id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 195 | RepoID: env.repoID(t), Pattern: "release/*", |
| 196 | PreventForcePush: true, |
| 197 | AllowedPusherUserIds: []int64{}, |
| 198 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 199 | }, 0, false) |
| 200 | |
| 201 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/%d", env.owner, env.repo, id)) |
| 202 | if rr.Code != http.StatusOK { |
| 203 | t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 204 | } |
| 205 | var got apiRuleset |
| 206 | if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { |
| 207 | t.Fatalf("decode: %v", err) |
| 208 | } |
| 209 | if got.ID != id || got.Name != "Pattern: release/*" { |
| 210 | t.Errorf("shape: %+v", got) |
| 211 | } |
| 212 | } |
| 213 | |
| 214 | func TestRulesets_GetUnknownReturns404(t *testing.T) { |
| 215 | env := newRulesetsEnv(t, "alice") |
| 216 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/9999", env.owner, env.repo)) |
| 217 | if rr.Code != http.StatusNotFound { |
| 218 | t.Fatalf("status: got %d, want 404; body=%s", rr.Code, rr.Body.String()) |
| 219 | } |
| 220 | } |
| 221 | |
| 222 | func TestRulesets_GetCrossRepoLeak404(t *testing.T) { |
| 223 | env := newRulesetsEnv(t, "alice") |
| 224 | // Spin up a second repo for the same owner and put a rule on it. |
| 225 | // Hitting the first repo's URL with the second repo's rule id |
| 226 | // must 404 — same status as "doesn't exist" to keep existence |
| 227 | // non-discoverable across repo boundaries. |
| 228 | otherID := seedSecondRepoForOwner(t, env, "demo2") |
| 229 | id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 230 | RepoID: otherID, Pattern: "trunk", |
| 231 | PreventDeletion: true, |
| 232 | AllowedPusherUserIds: []int64{}, |
| 233 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 234 | }, 0, false) |
| 235 | |
| 236 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/%d", env.owner, env.repo, id)) |
| 237 | if rr.Code != http.StatusNotFound { |
| 238 | t.Fatalf("cross-repo status: got %d, want 404; body=%s", rr.Code, rr.Body.String()) |
| 239 | } |
| 240 | } |
| 241 | |
| 242 | func TestRulesets_RulesForBranchListsAllMatches(t *testing.T) { |
| 243 | env := newRulesetsEnv(t, "alice") |
| 244 | repoID := env.repoID(t) |
| 245 | idWild := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 246 | RepoID: repoID, Pattern: "release/*", |
| 247 | PreventForcePush: true, |
| 248 | AllowedPusherUserIds: []int64{}, |
| 249 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 250 | }, 0, false) |
| 251 | idExact := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 252 | RepoID: repoID, Pattern: "release/v1.0", |
| 253 | PreventDeletion: true, |
| 254 | AllowedPusherUserIds: []int64{}, |
| 255 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 256 | }, 0, false) |
| 257 | |
| 258 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rules/branches/release/v1.0", env.owner, env.repo)) |
| 259 | if rr.Code != http.StatusOK { |
| 260 | t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 261 | } |
| 262 | var listed []apiRuleset |
| 263 | if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 264 | t.Fatalf("decode: %v", err) |
| 265 | } |
| 266 | if len(listed) != 2 { |
| 267 | t.Fatalf("len: got %d, want 2 (wildcard + exact); payload=%+v", len(listed), listed) |
| 268 | } |
| 269 | seen := map[int64]bool{} |
| 270 | for _, rs := range listed { |
| 271 | seen[rs.ID] = true |
| 272 | } |
| 273 | if !seen[idWild] || !seen[idExact] { |
| 274 | t.Errorf("missing expected rule id; seen=%v", seen) |
| 275 | } |
| 276 | } |
| 277 | |
| 278 | func TestRulesets_RulesForBranchNoMatch(t *testing.T) { |
| 279 | env := newRulesetsEnv(t, "alice") |
| 280 | seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 281 | RepoID: env.repoID(t), Pattern: "release/*", |
| 282 | PreventForcePush: true, |
| 283 | AllowedPusherUserIds: []int64{}, |
| 284 | CreatedByUserID: pgtype.Int8{Valid: false}, |
| 285 | }, 0, false) |
| 286 | |
| 287 | rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rules/branches/feature/x", env.owner, env.repo)) |
| 288 | if rr.Code != http.StatusOK { |
| 289 | t.Fatalf("status: got %d; body=%s", rr.Code, rr.Body.String()) |
| 290 | } |
| 291 | var listed []apiRuleset |
| 292 | if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 293 | t.Fatalf("decode: %v", err) |
| 294 | } |
| 295 | if len(listed) != 0 { |
| 296 | t.Errorf("expected no rules; got %+v", listed) |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | func TestRulesets_RequiresReadScope(t *testing.T) { |
| 301 | env := newRulesetsEnv(t, "alice") |
| 302 | // Mint a user:read-only token for a different actor; rulesets |
| 303 | // list requires repo:read so this must 403. |
| 304 | otherID := seedRepoCreatorUser(t, env.pool, "carol") |
| 305 | wrongScope := mintRunnerAPIPAT(t, env.pool, otherID, string(pat.ScopeUserRead)) |
| 306 | req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo), nil) |
| 307 | req.Header.Set("Authorization", "Bearer "+wrongScope) |
| 308 | rr := httptest.NewRecorder() |
| 309 | env.router.ServeHTTP(rr, req) |
| 310 | if rr.Code != http.StatusForbidden { |
| 311 | t.Fatalf("status: got %d, want 403; body=%s", rr.Code, rr.Body.String()) |
| 312 | } |
| 313 | } |
| 314 | |
| 315 | // seedSecondRepoForOwner creates a second repo under the existing |
| 316 | // owner; returns the new repo's id. Used by the cross-repo leak |
| 317 | // test to confirm rulesets/{id} doesn't disclose rules belonging |
| 318 | // to a different repo under the same owner. |
| 319 | func seedSecondRepoForOwner(t *testing.T, env rulesetsEnv, name string) int64 { |
| 320 | t.Helper() |
| 321 | // Look up the existing owner — seedRepoCreatorUser is NOT |
| 322 | // idempotent (it inserts) so we resolve by username instead. |
| 323 | user, err := usersdb.New().GetUserByUsername(context.Background(), env.pool, env.owner) |
| 324 | if err != nil { |
| 325 | t.Fatalf("GetUserByUsername %q: %v", env.owner, err) |
| 326 | } |
| 327 | creatorID := user.ID |
| 328 | row, err := repos.Create(context.Background(), repos.Deps{ |
| 329 | Pool: env.pool, |
| 330 | RepoFS: env.rfs, |
| 331 | Audit: audit.NewRecorder(), |
| 332 | Limiter: throttle.NewLimiter(), |
| 333 | }, repos.Params{ |
| 334 | ActorUserID: creatorID, |
| 335 | OwnerUserID: creatorID, |
| 336 | OwnerUsername: env.owner, |
| 337 | Name: name, |
| 338 | Description: "secondary repo", |
| 339 | Visibility: "public", |
| 340 | }) |
| 341 | if err != nil { |
| 342 | t.Fatalf("repos.Create %q: %v", name, err) |
| 343 | } |
| 344 | return row.Repo.ID |
| 345 | } |
| 346 |