@@ -0,0 +1,345 @@ |
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | + |
| 3 | +package api_test |
| 4 | + |
| 5 | +import ( |
| 6 | + "context" |
| 7 | + "encoding/json" |
| 8 | + "fmt" |
| 9 | + "net/http" |
| 10 | + "net/http/httptest" |
| 11 | + "strings" |
| 12 | + "testing" |
| 13 | + |
| 14 | + "github.com/jackc/pgx/v5/pgtype" |
| 15 | + "github.com/jackc/pgx/v5/pgxpool" |
| 16 | + |
| 17 | + "github.com/tenseleyFlow/shithub/internal/auth/audit" |
| 18 | + "github.com/tenseleyFlow/shithub/internal/auth/pat" |
| 19 | + "github.com/tenseleyFlow/shithub/internal/auth/throttle" |
| 20 | + "github.com/tenseleyFlow/shithub/internal/infra/storage" |
| 21 | + "github.com/tenseleyFlow/shithub/internal/repos" |
| 22 | + reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc" |
| 23 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 24 | +) |
| 25 | + |
| 26 | +// apiRuleset mirrors handlers/api.rulesetResponse for decoding. |
| 27 | +type apiRuleset struct { |
| 28 | + ID int64 `json:"id"` |
| 29 | + Name string `json:"name"` |
| 30 | + Target string `json:"target"` |
| 31 | + SourceType string `json:"source_type"` |
| 32 | + Source string `json:"source"` |
| 33 | + Enforcement string `json:"enforcement"` |
| 34 | + Conditions apiRulesetConditions `json:"conditions"` |
| 35 | + Rules []apiRulesetRule `json:"rules"` |
| 36 | +} |
| 37 | + |
| 38 | +type apiRulesetConditions struct { |
| 39 | + RefName apiRulesetRefName `json:"ref_name"` |
| 40 | +} |
| 41 | + |
| 42 | +type apiRulesetRefName struct { |
| 43 | + Include []string `json:"include"` |
| 44 | + Exclude []string `json:"exclude"` |
| 45 | +} |
| 46 | + |
| 47 | +type apiRulesetRule struct { |
| 48 | + Type string `json:"type"` |
| 49 | + Parameters map[string]any `json:"parameters,omitempty"` |
| 50 | +} |
| 51 | + |
| 52 | +type rulesetsEnv struct { |
| 53 | + pool *pgxpool.Pool |
| 54 | + router http.Handler |
| 55 | + rfs *storage.RepoFS |
| 56 | + token string |
| 57 | + owner string |
| 58 | + repo string |
| 59 | +} |
| 60 | + |
| 61 | +func newRulesetsEnv(t *testing.T, ownerUsername string) rulesetsEnv { |
| 62 | + t.Helper() |
| 63 | + pool, router, rfs, token, owner, repoName := seedBranchesEnv(t, ownerUsername) |
| 64 | + return rulesetsEnv{pool: pool, router: router, rfs: rfs, token: token, owner: owner, repo: repoName} |
| 65 | +} |
| 66 | + |
| 67 | +func (e rulesetsEnv) get(t *testing.T, path string) *httptest.ResponseRecorder { |
| 68 | + t.Helper() |
| 69 | + req := httptest.NewRequest(http.MethodGet, path, nil) |
| 70 | + req.Header.Set("Authorization", "Bearer "+e.token) |
| 71 | + rr := httptest.NewRecorder() |
| 72 | + e.router.ServeHTTP(rr, req) |
| 73 | + return rr |
| 74 | +} |
| 75 | + |
| 76 | +// repoID looks the repo row up via the existing repos REST endpoint; |
| 77 | +// keeps the test independent of internal package wiring. |
| 78 | +func (e rulesetsEnv) repoID(t *testing.T) int64 { |
| 79 | + t.Helper() |
| 80 | + rr := e.get(t, fmt.Sprintf("/api/v1/repos/%s/%s", e.owner, e.repo)) |
| 81 | + if rr.Code != http.StatusOK { |
| 82 | + t.Fatalf("repo lookup status %d; body=%s", rr.Code, rr.Body.String()) |
| 83 | + } |
| 84 | + var got struct { |
| 85 | + ID int64 `json:"id"` |
| 86 | + } |
| 87 | + if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { |
| 88 | + t.Fatalf("decode repo id: %v", err) |
| 89 | + } |
| 90 | + return got.ID |
| 91 | +} |
| 92 | + |
| 93 | +// seedRule upserts a protection rule and (optionally) layers review |
| 94 | +// settings on top. Returns the rule ID. |
| 95 | +func seedRule(t *testing.T, pool *pgxpool.Pool, p reposdb.UpsertBranchProtectionRuleParams, reviewCount int32, requireCodeOwner bool) int64 { |
| 96 | + t.Helper() |
| 97 | + rq := reposdb.New() |
| 98 | + id, err := rq.UpsertBranchProtectionRule(context.Background(), pool, p) |
| 99 | + if err != nil { |
| 100 | + t.Fatalf("UpsertBranchProtectionRule: %v", err) |
| 101 | + } |
| 102 | + if reviewCount > 0 || requireCodeOwner { |
| 103 | + if err := rq.UpdateBranchProtectionReviewSettings(context.Background(), pool, reposdb.UpdateBranchProtectionReviewSettingsParams{ |
| 104 | + ID: id, |
| 105 | + RequiredReviewCount: reviewCount, |
| 106 | + RequireCodeOwnerReview: requireCodeOwner, |
| 107 | + }); err != nil { |
| 108 | + t.Fatalf("UpdateBranchProtectionReviewSettings: %v", err) |
| 109 | + } |
| 110 | + } |
| 111 | + return id |
| 112 | +} |
| 113 | + |
| 114 | +func TestRulesets_ListEmpty(t *testing.T) { |
| 115 | + env := newRulesetsEnv(t, "alice") |
| 116 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo)) |
| 117 | + if rr.Code != http.StatusOK { |
| 118 | + t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 119 | + } |
| 120 | + var listed []apiRuleset |
| 121 | + if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 122 | + t.Fatalf("decode: %v", err) |
| 123 | + } |
| 124 | + if len(listed) != 0 { |
| 125 | + t.Errorf("expected empty list, got %d", len(listed)) |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +func TestRulesets_ListProjectsRule(t *testing.T) { |
| 130 | + env := newRulesetsEnv(t, "alice") |
| 131 | + id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 132 | + RepoID: env.repoID(t), |
| 133 | + Pattern: "trunk", |
| 134 | + PreventForcePush: true, |
| 135 | + PreventDeletion: true, |
| 136 | + RequirePrForPush: false, |
| 137 | + AllowedPusherUserIds: []int64{}, |
| 138 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 139 | + }, 2, true) |
| 140 | + |
| 141 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo)) |
| 142 | + if rr.Code != http.StatusOK { |
| 143 | + t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 144 | + } |
| 145 | + var listed []apiRuleset |
| 146 | + if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 147 | + t.Fatalf("decode: %v", err) |
| 148 | + } |
| 149 | + if len(listed) != 1 { |
| 150 | + t.Fatalf("len: got %d, want 1; payload=%+v", len(listed), listed) |
| 151 | + } |
| 152 | + rs := listed[0] |
| 153 | + if rs.ID != id { |
| 154 | + t.Errorf("id: got %d, want %d", rs.ID, id) |
| 155 | + } |
| 156 | + if rs.Name != "Pattern: trunk" { |
| 157 | + t.Errorf("name: %q", rs.Name) |
| 158 | + } |
| 159 | + if rs.Target != "branch" || rs.SourceType != "Repository" || rs.Enforcement != "active" { |
| 160 | + t.Errorf("envelope fields: %+v", rs) |
| 161 | + } |
| 162 | + wantSource := strings.ToLower(env.owner) + "/" + env.repo |
| 163 | + if rs.Source != wantSource { |
| 164 | + t.Errorf("source: got %q, want %q", rs.Source, wantSource) |
| 165 | + } |
| 166 | + if len(rs.Conditions.RefName.Include) != 1 || rs.Conditions.RefName.Include[0] != "refs/heads/trunk" { |
| 167 | + t.Errorf("conditions.ref_name.include: %+v", rs.Conditions.RefName.Include) |
| 168 | + } |
| 169 | + |
| 170 | + have := map[string]apiRulesetRule{} |
| 171 | + for _, r := range rs.Rules { |
| 172 | + have[r.Type] = r |
| 173 | + } |
| 174 | + if _, ok := have["non_fast_forward"]; !ok { |
| 175 | + t.Errorf("missing non_fast_forward rule; rules=%+v", rs.Rules) |
| 176 | + } |
| 177 | + if _, ok := have["deletion"]; !ok { |
| 178 | + t.Errorf("missing deletion rule; rules=%+v", rs.Rules) |
| 179 | + } |
| 180 | + pr, ok := have["pull_request"] |
| 181 | + if !ok { |
| 182 | + t.Fatalf("missing pull_request rule; rules=%+v", rs.Rules) |
| 183 | + } |
| 184 | + if v, _ := pr.Parameters["required_approving_review_count"].(float64); int(v) != 2 { |
| 185 | + t.Errorf("pull_request.required_approving_review_count: %v", pr.Parameters["required_approving_review_count"]) |
| 186 | + } |
| 187 | + if v, _ := pr.Parameters["require_code_owner_review"].(bool); !v { |
| 188 | + t.Errorf("pull_request.require_code_owner_review: %v", pr.Parameters["require_code_owner_review"]) |
| 189 | + } |
| 190 | +} |
| 191 | + |
| 192 | +func TestRulesets_GetSingle(t *testing.T) { |
| 193 | + env := newRulesetsEnv(t, "alice") |
| 194 | + id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 195 | + RepoID: env.repoID(t), Pattern: "release/*", |
| 196 | + PreventForcePush: true, |
| 197 | + AllowedPusherUserIds: []int64{}, |
| 198 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 199 | + }, 0, false) |
| 200 | + |
| 201 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/%d", env.owner, env.repo, id)) |
| 202 | + if rr.Code != http.StatusOK { |
| 203 | + t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 204 | + } |
| 205 | + var got apiRuleset |
| 206 | + if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil { |
| 207 | + t.Fatalf("decode: %v", err) |
| 208 | + } |
| 209 | + if got.ID != id || got.Name != "Pattern: release/*" { |
| 210 | + t.Errorf("shape: %+v", got) |
| 211 | + } |
| 212 | +} |
| 213 | + |
| 214 | +func TestRulesets_GetUnknownReturns404(t *testing.T) { |
| 215 | + env := newRulesetsEnv(t, "alice") |
| 216 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/9999", env.owner, env.repo)) |
| 217 | + if rr.Code != http.StatusNotFound { |
| 218 | + t.Fatalf("status: got %d, want 404; body=%s", rr.Code, rr.Body.String()) |
| 219 | + } |
| 220 | +} |
| 221 | + |
| 222 | +func TestRulesets_GetCrossRepoLeak404(t *testing.T) { |
| 223 | + env := newRulesetsEnv(t, "alice") |
| 224 | + // Spin up a second repo for the same owner and put a rule on it. |
| 225 | + // Hitting the first repo's URL with the second repo's rule id |
| 226 | + // must 404 — same status as "doesn't exist" to keep existence |
| 227 | + // non-discoverable across repo boundaries. |
| 228 | + otherID := seedSecondRepoForOwner(t, env, "demo2") |
| 229 | + id := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 230 | + RepoID: otherID, Pattern: "trunk", |
| 231 | + PreventDeletion: true, |
| 232 | + AllowedPusherUserIds: []int64{}, |
| 233 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 234 | + }, 0, false) |
| 235 | + |
| 236 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets/%d", env.owner, env.repo, id)) |
| 237 | + if rr.Code != http.StatusNotFound { |
| 238 | + t.Fatalf("cross-repo status: got %d, want 404; body=%s", rr.Code, rr.Body.String()) |
| 239 | + } |
| 240 | +} |
| 241 | + |
| 242 | +func TestRulesets_RulesForBranchListsAllMatches(t *testing.T) { |
| 243 | + env := newRulesetsEnv(t, "alice") |
| 244 | + repoID := env.repoID(t) |
| 245 | + idWild := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 246 | + RepoID: repoID, Pattern: "release/*", |
| 247 | + PreventForcePush: true, |
| 248 | + AllowedPusherUserIds: []int64{}, |
| 249 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 250 | + }, 0, false) |
| 251 | + idExact := seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 252 | + RepoID: repoID, Pattern: "release/v1.0", |
| 253 | + PreventDeletion: true, |
| 254 | + AllowedPusherUserIds: []int64{}, |
| 255 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 256 | + }, 0, false) |
| 257 | + |
| 258 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rules/branches/release/v1.0", env.owner, env.repo)) |
| 259 | + if rr.Code != http.StatusOK { |
| 260 | + t.Fatalf("status: got %d, want 200; body=%s", rr.Code, rr.Body.String()) |
| 261 | + } |
| 262 | + var listed []apiRuleset |
| 263 | + if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 264 | + t.Fatalf("decode: %v", err) |
| 265 | + } |
| 266 | + if len(listed) != 2 { |
| 267 | + t.Fatalf("len: got %d, want 2 (wildcard + exact); payload=%+v", len(listed), listed) |
| 268 | + } |
| 269 | + seen := map[int64]bool{} |
| 270 | + for _, rs := range listed { |
| 271 | + seen[rs.ID] = true |
| 272 | + } |
| 273 | + if !seen[idWild] || !seen[idExact] { |
| 274 | + t.Errorf("missing expected rule id; seen=%v", seen) |
| 275 | + } |
| 276 | +} |
| 277 | + |
| 278 | +func TestRulesets_RulesForBranchNoMatch(t *testing.T) { |
| 279 | + env := newRulesetsEnv(t, "alice") |
| 280 | + seedRule(t, env.pool, reposdb.UpsertBranchProtectionRuleParams{ |
| 281 | + RepoID: env.repoID(t), Pattern: "release/*", |
| 282 | + PreventForcePush: true, |
| 283 | + AllowedPusherUserIds: []int64{}, |
| 284 | + CreatedByUserID: pgtype.Int8{Valid: false}, |
| 285 | + }, 0, false) |
| 286 | + |
| 287 | + rr := env.get(t, fmt.Sprintf("/api/v1/repos/%s/%s/rules/branches/feature/x", env.owner, env.repo)) |
| 288 | + if rr.Code != http.StatusOK { |
| 289 | + t.Fatalf("status: got %d; body=%s", rr.Code, rr.Body.String()) |
| 290 | + } |
| 291 | + var listed []apiRuleset |
| 292 | + if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { |
| 293 | + t.Fatalf("decode: %v", err) |
| 294 | + } |
| 295 | + if len(listed) != 0 { |
| 296 | + t.Errorf("expected no rules; got %+v", listed) |
| 297 | + } |
| 298 | +} |
| 299 | + |
| 300 | +func TestRulesets_RequiresReadScope(t *testing.T) { |
| 301 | + env := newRulesetsEnv(t, "alice") |
| 302 | + // Mint a user:read-only token for a different actor; rulesets |
| 303 | + // list requires repo:read so this must 403. |
| 304 | + otherID := seedRepoCreatorUser(t, env.pool, "carol") |
| 305 | + wrongScope := mintRunnerAPIPAT(t, env.pool, otherID, string(pat.ScopeUserRead)) |
| 306 | + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/v1/repos/%s/%s/rulesets", env.owner, env.repo), nil) |
| 307 | + req.Header.Set("Authorization", "Bearer "+wrongScope) |
| 308 | + rr := httptest.NewRecorder() |
| 309 | + env.router.ServeHTTP(rr, req) |
| 310 | + if rr.Code != http.StatusForbidden { |
| 311 | + t.Fatalf("status: got %d, want 403; body=%s", rr.Code, rr.Body.String()) |
| 312 | + } |
| 313 | +} |
| 314 | + |
| 315 | +// seedSecondRepoForOwner creates a second repo under the existing |
| 316 | +// owner; returns the new repo's id. Used by the cross-repo leak |
| 317 | +// test to confirm rulesets/{id} doesn't disclose rules belonging |
| 318 | +// to a different repo under the same owner. |
| 319 | +func seedSecondRepoForOwner(t *testing.T, env rulesetsEnv, name string) int64 { |
| 320 | + t.Helper() |
| 321 | + // Look up the existing owner — seedRepoCreatorUser is NOT |
| 322 | + // idempotent (it inserts) so we resolve by username instead. |
| 323 | + user, err := usersdb.New().GetUserByUsername(context.Background(), env.pool, env.owner) |
| 324 | + if err != nil { |
| 325 | + t.Fatalf("GetUserByUsername %q: %v", env.owner, err) |
| 326 | + } |
| 327 | + creatorID := user.ID |
| 328 | + row, err := repos.Create(context.Background(), repos.Deps{ |
| 329 | + Pool: env.pool, |
| 330 | + RepoFS: env.rfs, |
| 331 | + Audit: audit.NewRecorder(), |
| 332 | + Limiter: throttle.NewLimiter(), |
| 333 | + }, repos.Params{ |
| 334 | + ActorUserID: creatorID, |
| 335 | + OwnerUserID: creatorID, |
| 336 | + OwnerUsername: env.owner, |
| 337 | + Name: name, |
| 338 | + Description: "secondary repo", |
| 339 | + Visibility: "public", |
| 340 | + }) |
| 341 | + if err != nil { |
| 342 | + t.Fatalf("repos.Create %q: %v", name, err) |
| 343 | + } |
| 344 | + return row.Repo.ID |
| 345 | +} |