@@ -13,10 +13,12 @@ import ( |
| 13 | 13 | "strings" |
| 14 | 14 | "testing" |
| 15 | 15 | "testing/fstest" |
| 16 | + "time" |
| 16 | 17 | |
| 17 | 18 | "github.com/go-chi/chi/v5" |
| 18 | 19 | "github.com/jackc/pgx/v5/pgxpool" |
| 19 | 20 | |
| 21 | + "github.com/tenseleyFlow/shithub/internal/billing" |
| 20 | 22 | orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 21 | 23 | "github.com/tenseleyFlow/shithub/internal/testing/dbtest" |
| 22 | 24 | orgsh "github.com/tenseleyFlow/shithub/internal/web/handlers/orgs" |
@@ -117,6 +119,93 @@ func TestTeamCreateBlocksSecretTeamsWithoutEntitlement(t *testing.T) { |
| 117 | 119 | } |
| 118 | 120 | } |
| 119 | 121 | |
| 122 | +func TestSecretTeamAddMemberRequiresEntitlementButRemoveAllowed(t *testing.T) { |
| 123 | + t.Parallel() |
| 124 | + ctx := context.Background() |
| 125 | + pool := dbtest.NewTestDB(t) |
| 126 | + ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 127 | + memberID := insertOrgAvatarUser(t, pool, "member") |
| 128 | + orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 129 | + if _, err := pool.Exec(ctx, `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, 'member')`, orgID, memberID); err != nil { |
| 130 | + t.Fatalf("insert org member: %v", err) |
| 131 | + } |
| 132 | + teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret") |
| 133 | + |
| 134 | + form := url.Values{"user_id": {strconv.FormatInt(memberID, 10)}, "role": {"member"}} |
| 135 | + body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form) |
| 136 | + if status != http.StatusSeeOther { |
| 137 | + t.Fatalf("free add status=%d body=%s", status, body) |
| 138 | + } |
| 139 | + if location != "/acme/teams/security?notice=secret-teams-upgrade" { |
| 140 | + t.Fatalf("free add redirect=%q", location) |
| 141 | + } |
| 142 | + assertTeamMemberCount(t, pool, teamID, 0) |
| 143 | + |
| 144 | + if _, err := pool.Exec(ctx, `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, teamID, memberID); err != nil { |
| 145 | + t.Fatalf("seed team member: %v", err) |
| 146 | + } |
| 147 | + assertTeamMemberCount(t, pool, teamID, 1) |
| 148 | + |
| 149 | + remove := url.Values{ |
| 150 | + "user_id": {strconv.FormatInt(memberID, 10)}, |
| 151 | + "action": {"remove"}, |
| 152 | + } |
| 153 | + body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", remove) |
| 154 | + if status != http.StatusSeeOther { |
| 155 | + t.Fatalf("remove status=%d body=%s", status, body) |
| 156 | + } |
| 157 | + assertTeamMemberCount(t, pool, teamID, 0) |
| 158 | + |
| 159 | + activateTeamPlanForTest(t, pool, orgID) |
| 160 | + body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form) |
| 161 | + if status != http.StatusSeeOther { |
| 162 | + t.Fatalf("team add status=%d body=%s", status, body) |
| 163 | + } |
| 164 | + assertTeamMemberCount(t, pool, teamID, 1) |
| 165 | +} |
| 166 | + |
| 167 | +func TestSecretTeamRepoGrantRequiresEntitlementButRevokeAllowed(t *testing.T) { |
| 168 | + t.Parallel() |
| 169 | + ctx := context.Background() |
| 170 | + pool := dbtest.NewTestDB(t) |
| 171 | + ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 172 | + orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 173 | + teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret") |
| 174 | + repoID := insertTeamRepoForTest(t, pool, orgID, "private-repo") |
| 175 | + |
| 176 | + form := url.Values{"repo_id": {strconv.FormatInt(repoID, 10)}, "role": {"write"}} |
| 177 | + body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form) |
| 178 | + if status != http.StatusSeeOther { |
| 179 | + t.Fatalf("free grant status=%d body=%s", status, body) |
| 180 | + } |
| 181 | + if location != "/acme/teams/security?notice=secret-teams-upgrade" { |
| 182 | + t.Fatalf("free grant redirect=%q", location) |
| 183 | + } |
| 184 | + assertTeamRepoGrantCount(t, pool, teamID, 0) |
| 185 | + |
| 186 | + if _, err := pool.Exec(ctx, `INSERT INTO team_repo_access (team_id, repo_id, role) VALUES ($1, $2, 'write')`, teamID, repoID); err != nil { |
| 187 | + t.Fatalf("seed team repo access: %v", err) |
| 188 | + } |
| 189 | + assertTeamRepoGrantCount(t, pool, teamID, 1) |
| 190 | + |
| 191 | + remove := url.Values{ |
| 192 | + "repo_id": {strconv.FormatInt(repoID, 10)}, |
| 193 | + "action": {"remove"}, |
| 194 | + } |
| 195 | + body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", remove) |
| 196 | + if status != http.StatusSeeOther { |
| 197 | + t.Fatalf("revoke status=%d body=%s", status, body) |
| 198 | + } |
| 199 | + assertTeamRepoGrantCount(t, pool, teamID, 0) |
| 200 | + |
| 201 | + activateTeamPlanForTest(t, pool, orgID) |
| 202 | + body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form) |
| 203 | + if status != http.StatusSeeOther { |
| 204 | + t.Fatalf("team grant status=%d body=%s", status, body) |
| 205 | + } |
| 206 | + assertTeamRepoGrantCount(t, pool, teamID, 1) |
| 207 | +} |
| 208 | + |
| 120 | 209 | func performTeamsListRequest(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, target string) (string, int, string) { |
| 121 | 210 | return performTeamsRequest(t, pool, viewer, http.MethodGet, target, nil) |
| 122 | 211 | } |
@@ -178,3 +267,57 @@ func insertTeamForTest(t *testing.T, db orgsdb.DBTX, orgID int64, slug, displayN |
| 178 | 267 | } |
| 179 | 268 | return id |
| 180 | 269 | } |
| 270 | + |
| 271 | +func insertTeamRepoForTest(t *testing.T, db orgsdb.DBTX, orgID int64, name string) int64 { |
| 272 | + t.Helper() |
| 273 | + var id int64 |
| 274 | + if err := db.QueryRow(context.Background(), |
| 275 | + `INSERT INTO repos (owner_org_id, name, visibility, default_branch) |
| 276 | + VALUES ($1, $2, 'private', 'trunk') |
| 277 | + RETURNING id`, |
| 278 | + orgID, name, |
| 279 | + ).Scan(&id); err != nil { |
| 280 | + t.Fatalf("insert repo: %v", err) |
| 281 | + } |
| 282 | + return id |
| 283 | +} |
| 284 | + |
| 285 | +func activateTeamPlanForTest(t *testing.T, pool *pgxpool.Pool, orgID int64) { |
| 286 | + t.Helper() |
| 287 | + now := time.Now().UTC().Truncate(time.Second) |
| 288 | + _, err := billing.ApplySubscriptionSnapshot(context.Background(), billing.Deps{Pool: pool}, billing.SubscriptionSnapshot{ |
| 289 | + OrgID: orgID, |
| 290 | + Plan: billing.PlanTeam, |
| 291 | + Status: billing.SubscriptionStatusActive, |
| 292 | + StripeSubscriptionID: "sub_teams_" + strconv.FormatInt(orgID, 10), |
| 293 | + StripeSubscriptionItemID: "si_teams_" + strconv.FormatInt(orgID, 10), |
| 294 | + CurrentPeriodStart: now, |
| 295 | + CurrentPeriodEnd: now.Add(30 * 24 * time.Hour), |
| 296 | + LastWebhookEventID: "evt_teams_" + strconv.FormatInt(orgID, 10), |
| 297 | + }) |
| 298 | + if err != nil { |
| 299 | + t.Fatalf("activate team plan: %v", err) |
| 300 | + } |
| 301 | +} |
| 302 | + |
| 303 | +func assertTeamMemberCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) { |
| 304 | + t.Helper() |
| 305 | + var count int |
| 306 | + if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_members WHERE team_id = $1`, teamID).Scan(&count); err != nil { |
| 307 | + t.Fatalf("count team members: %v", err) |
| 308 | + } |
| 309 | + if count != want { |
| 310 | + t.Fatalf("team member count=%d, want %d", count, want) |
| 311 | + } |
| 312 | +} |
| 313 | + |
| 314 | +func assertTeamRepoGrantCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) { |
| 315 | + t.Helper() |
| 316 | + var count int |
| 317 | + if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_repo_access WHERE team_id = $1`, teamID).Scan(&count); err != nil { |
| 318 | + t.Fatalf("count team repo grants: %v", err) |
| 319 | + } |
| 320 | + if count != want { |
| 321 | + t.Fatalf("team repo grant count=%d, want %d", count, want) |
| 322 | + } |
| 323 | +} |