| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "io" |
| 8 | "log/slog" |
| 9 | "net/http" |
| 10 | "net/http/httptest" |
| 11 | "net/url" |
| 12 | "strconv" |
| 13 | "strings" |
| 14 | "testing" |
| 15 | "testing/fstest" |
| 16 | "time" |
| 17 | |
| 18 | "github.com/go-chi/chi/v5" |
| 19 | "github.com/jackc/pgx/v5/pgxpool" |
| 20 | |
| 21 | "github.com/tenseleyFlow/shithub/internal/billing" |
| 22 | orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 23 | "github.com/tenseleyFlow/shithub/internal/testing/dbtest" |
| 24 | orgsh "github.com/tenseleyFlow/shithub/internal/web/handlers/orgs" |
| 25 | "github.com/tenseleyFlow/shithub/internal/web/middleware" |
| 26 | "github.com/tenseleyFlow/shithub/internal/web/render" |
| 27 | ) |
| 28 | |
| 29 | func TestTeamsListRequiresOrgMemberAndFiltersSecretTeams(t *testing.T) { |
| 30 | t.Parallel() |
| 31 | ctx := context.Background() |
| 32 | pool := dbtest.NewTestDB(t) |
| 33 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 34 | memberID := insertOrgAvatarUser(t, pool, "member") |
| 35 | outsiderID := insertOrgAvatarUser(t, pool, "outsider") |
| 36 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 37 | if _, err := pool.Exec(ctx, `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, 'member')`, orgID, memberID); err != nil { |
| 38 | t.Fatalf("insert org member: %v", err) |
| 39 | } |
| 40 | visibleTeamID := insertTeamForTest(t, pool, orgID, "engineering", "Engineering", "visible") |
| 41 | insertTeamForTest(t, pool, orgID, "security", "Security", "secret") |
| 42 | if _, err := pool.Exec(ctx, `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, visibleTeamID, memberID); err != nil { |
| 43 | t.Fatalf("insert team member: %v", err) |
| 44 | } |
| 45 | |
| 46 | memberBody, memberStatus, _ := performTeamsListRequest(t, pool, middleware.CurrentUser{ID: memberID, Username: "member"}, "/acme/teams") |
| 47 | if memberStatus != http.StatusOK { |
| 48 | t.Fatalf("member status=%d body=%s", memberStatus, memberBody) |
| 49 | } |
| 50 | if !strings.Contains(memberBody, "TEAM=engineering:Engineering:1:0") { |
| 51 | t.Fatalf("expected visible team with counts, got %s", memberBody) |
| 52 | } |
| 53 | if strings.Contains(memberBody, "security") { |
| 54 | t.Fatalf("secret team leaked to non-team member: %s", memberBody) |
| 55 | } |
| 56 | |
| 57 | outsiderBody, outsiderStatus, _ := performTeamsListRequest(t, pool, middleware.CurrentUser{ID: outsiderID, Username: "outsider"}, "/acme/teams") |
| 58 | if outsiderStatus != http.StatusNotFound { |
| 59 | t.Fatalf("outsider status=%d body=%s", outsiderStatus, outsiderBody) |
| 60 | } |
| 61 | |
| 62 | _, anonymousStatus, anonymousLocation := performTeamsListRequest(t, pool, middleware.CurrentUser{}, "/acme/teams") |
| 63 | if anonymousStatus != http.StatusSeeOther { |
| 64 | t.Fatalf("anonymous status=%d", anonymousStatus) |
| 65 | } |
| 66 | if !strings.HasPrefix(anonymousLocation, "/login?next=") { |
| 67 | t.Fatalf("anonymous redirect=%q", anonymousLocation) |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | func TestTeamMemberAddRejectsNonOrgUsers(t *testing.T) { |
| 72 | t.Parallel() |
| 73 | ctx := context.Background() |
| 74 | pool := dbtest.NewTestDB(t) |
| 75 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 76 | outsiderID := insertOrgAvatarUser(t, pool, "outsider") |
| 77 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 78 | teamID := insertTeamForTest(t, pool, orgID, "engineering", "Engineering", "visible") |
| 79 | |
| 80 | form := url.Values{"user_id": {strconv.FormatInt(outsiderID, 10)}, "role": {"member"}} |
| 81 | body, status, _ := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/engineering/members", form) |
| 82 | if status != http.StatusBadRequest { |
| 83 | t.Fatalf("status=%d body=%s", status, body) |
| 84 | } |
| 85 | var count int |
| 86 | if err := pool.QueryRow(ctx, `SELECT count(*) FROM team_members WHERE team_id = $1`, teamID).Scan(&count); err != nil { |
| 87 | t.Fatalf("count team members: %v", err) |
| 88 | } |
| 89 | if count != 0 { |
| 90 | t.Fatalf("expected no team member insert, got %d", count) |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | func TestTeamCreateBlocksSecretTeamsWithoutEntitlement(t *testing.T) { |
| 95 | t.Parallel() |
| 96 | ctx := context.Background() |
| 97 | pool := dbtest.NewTestDB(t) |
| 98 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 99 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 100 | |
| 101 | form := url.Values{ |
| 102 | "display_name": {"Security"}, |
| 103 | "slug": {"security"}, |
| 104 | "privacy": {"secret"}, |
| 105 | } |
| 106 | body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams", form) |
| 107 | if status != http.StatusSeeOther { |
| 108 | t.Fatalf("status=%d body=%s", status, body) |
| 109 | } |
| 110 | if location != "/acme/teams?notice=secret-teams-upgrade" { |
| 111 | t.Fatalf("redirect=%q", location) |
| 112 | } |
| 113 | var count int |
| 114 | if err := pool.QueryRow(ctx, `SELECT count(*) FROM teams WHERE org_id = $1`, orgID).Scan(&count); err != nil { |
| 115 | t.Fatalf("count teams: %v", err) |
| 116 | } |
| 117 | if count != 0 { |
| 118 | t.Fatalf("expected no secret team insert, got %d", count) |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | func TestSecretTeamAddMemberRequiresEntitlementButRemoveAllowed(t *testing.T) { |
| 123 | t.Parallel() |
| 124 | ctx := context.Background() |
| 125 | pool := dbtest.NewTestDB(t) |
| 126 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 127 | memberID := insertOrgAvatarUser(t, pool, "member") |
| 128 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 129 | if _, err := pool.Exec(ctx, `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, 'member')`, orgID, memberID); err != nil { |
| 130 | t.Fatalf("insert org member: %v", err) |
| 131 | } |
| 132 | teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret") |
| 133 | |
| 134 | form := url.Values{"user_id": {strconv.FormatInt(memberID, 10)}, "role": {"member"}} |
| 135 | body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form) |
| 136 | if status != http.StatusSeeOther { |
| 137 | t.Fatalf("free add status=%d body=%s", status, body) |
| 138 | } |
| 139 | if location != "/acme/teams/security?notice=secret-teams-upgrade" { |
| 140 | t.Fatalf("free add redirect=%q", location) |
| 141 | } |
| 142 | assertTeamMemberCount(t, pool, teamID, 0) |
| 143 | |
| 144 | if _, err := pool.Exec(ctx, `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, teamID, memberID); err != nil { |
| 145 | t.Fatalf("seed team member: %v", err) |
| 146 | } |
| 147 | assertTeamMemberCount(t, pool, teamID, 1) |
| 148 | |
| 149 | remove := url.Values{ |
| 150 | "user_id": {strconv.FormatInt(memberID, 10)}, |
| 151 | "action": {"remove"}, |
| 152 | } |
| 153 | body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", remove) |
| 154 | if status != http.StatusSeeOther { |
| 155 | t.Fatalf("remove status=%d body=%s", status, body) |
| 156 | } |
| 157 | assertTeamMemberCount(t, pool, teamID, 0) |
| 158 | |
| 159 | activateTeamPlanForTest(t, pool, orgID) |
| 160 | body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form) |
| 161 | if status != http.StatusSeeOther { |
| 162 | t.Fatalf("team add status=%d body=%s", status, body) |
| 163 | } |
| 164 | assertTeamMemberCount(t, pool, teamID, 1) |
| 165 | } |
| 166 | |
| 167 | func TestSecretTeamRepoGrantRequiresEntitlementButRevokeAllowed(t *testing.T) { |
| 168 | t.Parallel() |
| 169 | ctx := context.Background() |
| 170 | pool := dbtest.NewTestDB(t) |
| 171 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 172 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 173 | teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret") |
| 174 | repoID := insertTeamRepoForTest(t, pool, orgID, "private-repo") |
| 175 | |
| 176 | form := url.Values{"repo_id": {strconv.FormatInt(repoID, 10)}, "role": {"write"}} |
| 177 | body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form) |
| 178 | if status != http.StatusSeeOther { |
| 179 | t.Fatalf("free grant status=%d body=%s", status, body) |
| 180 | } |
| 181 | if location != "/acme/teams/security?notice=secret-teams-upgrade" { |
| 182 | t.Fatalf("free grant redirect=%q", location) |
| 183 | } |
| 184 | assertTeamRepoGrantCount(t, pool, teamID, 0) |
| 185 | |
| 186 | if _, err := pool.Exec(ctx, `INSERT INTO team_repo_access (team_id, repo_id, role) VALUES ($1, $2, 'write')`, teamID, repoID); err != nil { |
| 187 | t.Fatalf("seed team repo access: %v", err) |
| 188 | } |
| 189 | assertTeamRepoGrantCount(t, pool, teamID, 1) |
| 190 | |
| 191 | remove := url.Values{ |
| 192 | "repo_id": {strconv.FormatInt(repoID, 10)}, |
| 193 | "action": {"remove"}, |
| 194 | } |
| 195 | body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", remove) |
| 196 | if status != http.StatusSeeOther { |
| 197 | t.Fatalf("revoke status=%d body=%s", status, body) |
| 198 | } |
| 199 | assertTeamRepoGrantCount(t, pool, teamID, 0) |
| 200 | |
| 201 | activateTeamPlanForTest(t, pool, orgID) |
| 202 | body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form) |
| 203 | if status != http.StatusSeeOther { |
| 204 | t.Fatalf("team grant status=%d body=%s", status, body) |
| 205 | } |
| 206 | assertTeamRepoGrantCount(t, pool, teamID, 1) |
| 207 | } |
| 208 | |
| 209 | func performTeamsListRequest(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, target string) (string, int, string) { |
| 210 | return performTeamsRequest(t, pool, viewer, http.MethodGet, target, nil) |
| 211 | } |
| 212 | |
| 213 | func performTeamsRequest(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, method, target string, form url.Values) (string, int, string) { |
| 214 | t.Helper() |
| 215 | rr, err := render.New(fstest.MapFS{ |
| 216 | "_layout.html": {Data: []byte(`{{ define "layout" }}<html><body>{{ template "page" . }}</body></html>{{ end }}`)}, |
| 217 | "orgs/teams_list.html": {Data: []byte(`{{ define "page" }}ACTIVE={{ .ActiveOrgNav }} TOTAL={{ .TeamTotalCount }}{{ range .Teams }} TEAM={{ .Slug }}:{{ .DisplayName }}:{{ .MemberCount }}:{{ .RepoCount }}{{ end }}{{ end }}`)}, |
| 218 | "orgs/team_view.html": {Data: []byte(`{{ define "page" }}TEAM{{ end }}`)}, |
| 219 | "orgs/people.html": {Data: []byte(`{{ define "page" }}PEOPLE{{ end }}`)}, |
| 220 | "errors/400.html": {Data: []byte(`{{ define "page" }}400{{ end }}`)}, |
| 221 | "errors/403.html": {Data: []byte(`{{ define "page" }}403{{ end }}`)}, |
| 222 | "errors/404.html": {Data: []byte(`{{ define "page" }}404{{ end }}`)}, |
| 223 | "errors/500.html": {Data: []byte(`{{ define "page" }}500{{ end }}`)}, |
| 224 | }, render.Options{}) |
| 225 | if err != nil { |
| 226 | t.Fatalf("render.New: %v", err) |
| 227 | } |
| 228 | h, err := orgsh.New(orgsh.Deps{ |
| 229 | Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), |
| 230 | Render: rr, |
| 231 | Pool: pool, |
| 232 | }) |
| 233 | if err != nil { |
| 234 | t.Fatalf("orgsh.New: %v", err) |
| 235 | } |
| 236 | r := chi.NewRouter() |
| 237 | r.Use(func(next http.Handler) http.Handler { |
| 238 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 239 | next.ServeHTTP(w, r.WithContext(middleware.WithCurrentUserForTest(r.Context(), viewer))) |
| 240 | }) |
| 241 | }) |
| 242 | h.MountOrgRoutes(r) |
| 243 | |
| 244 | var body io.Reader |
| 245 | if form != nil { |
| 246 | body = strings.NewReader(form.Encode()) |
| 247 | } |
| 248 | req := httptest.NewRequest(method, target, body) |
| 249 | if form != nil { |
| 250 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 251 | } |
| 252 | rec := httptest.NewRecorder() |
| 253 | r.ServeHTTP(rec, req) |
| 254 | return rec.Body.String(), rec.Code, rec.Header().Get("Location") |
| 255 | } |
| 256 | |
| 257 | func insertTeamForTest(t *testing.T, db orgsdb.DBTX, orgID int64, slug, displayName, privacy string) int64 { |
| 258 | t.Helper() |
| 259 | var id int64 |
| 260 | if err := db.QueryRow(context.Background(), |
| 261 | `INSERT INTO teams (org_id, slug, display_name, privacy) |
| 262 | VALUES ($1, $2, $3, $4) |
| 263 | RETURNING id`, |
| 264 | orgID, slug, displayName, privacy, |
| 265 | ).Scan(&id); err != nil { |
| 266 | t.Fatalf("insert team: %v", err) |
| 267 | } |
| 268 | return id |
| 269 | } |
| 270 | |
| 271 | func insertTeamRepoForTest(t *testing.T, db orgsdb.DBTX, orgID int64, name string) int64 { |
| 272 | t.Helper() |
| 273 | var id int64 |
| 274 | if err := db.QueryRow(context.Background(), |
| 275 | `INSERT INTO repos (owner_org_id, name, visibility, default_branch) |
| 276 | VALUES ($1, $2, 'private', 'trunk') |
| 277 | RETURNING id`, |
| 278 | orgID, name, |
| 279 | ).Scan(&id); err != nil { |
| 280 | t.Fatalf("insert repo: %v", err) |
| 281 | } |
| 282 | return id |
| 283 | } |
| 284 | |
| 285 | func activateTeamPlanForTest(t *testing.T, pool *pgxpool.Pool, orgID int64) { |
| 286 | t.Helper() |
| 287 | now := time.Now().UTC().Truncate(time.Second) |
| 288 | _, err := billing.ApplySubscriptionSnapshot(context.Background(), billing.Deps{Pool: pool}, billing.SubscriptionSnapshot{ |
| 289 | OrgID: orgID, |
| 290 | Plan: billing.PlanTeam, |
| 291 | Status: billing.SubscriptionStatusActive, |
| 292 | StripeSubscriptionID: "sub_teams_" + strconv.FormatInt(orgID, 10), |
| 293 | StripeSubscriptionItemID: "si_teams_" + strconv.FormatInt(orgID, 10), |
| 294 | CurrentPeriodStart: now, |
| 295 | CurrentPeriodEnd: now.Add(30 * 24 * time.Hour), |
| 296 | LastWebhookEventID: "evt_teams_" + strconv.FormatInt(orgID, 10), |
| 297 | }) |
| 298 | if err != nil { |
| 299 | t.Fatalf("activate team plan: %v", err) |
| 300 | } |
| 301 | } |
| 302 | |
| 303 | func assertTeamMemberCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) { |
| 304 | t.Helper() |
| 305 | var count int |
| 306 | if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_members WHERE team_id = $1`, teamID).Scan(&count); err != nil { |
| 307 | t.Fatalf("count team members: %v", err) |
| 308 | } |
| 309 | if count != want { |
| 310 | t.Fatalf("team member count=%d, want %d", count, want) |
| 311 | } |
| 312 | } |
| 313 | |
| 314 | func assertTeamRepoGrantCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) { |
| 315 | t.Helper() |
| 316 | var count int |
| 317 | if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_repo_access WHERE team_id = $1`, teamID).Scan(&count); err != nil { |
| 318 | t.Fatalf("count team repo grants: %v", err) |
| 319 | } |
| 320 | if count != want { |
| 321 | t.Fatalf("team repo grant count=%d, want %d", count, want) |
| 322 | } |
| 323 | } |
| 324 |