Go · 12905 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package orgs_test
4
5 import (
6 "context"
7 "io"
8 "log/slog"
9 "net/http"
10 "net/http/httptest"
11 "net/url"
12 "strconv"
13 "strings"
14 "testing"
15 "testing/fstest"
16 "time"
17
18 "github.com/go-chi/chi/v5"
19 "github.com/jackc/pgx/v5/pgxpool"
20
21 "github.com/tenseleyFlow/shithub/internal/billing"
22 orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc"
23 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
24 orgsh "github.com/tenseleyFlow/shithub/internal/web/handlers/orgs"
25 "github.com/tenseleyFlow/shithub/internal/web/middleware"
26 "github.com/tenseleyFlow/shithub/internal/web/render"
27 )
28
29 func TestTeamsListRequiresOrgMemberAndFiltersSecretTeams(t *testing.T) {
30 t.Parallel()
31 ctx := context.Background()
32 pool := dbtest.NewTestDB(t)
33 ownerID := insertOrgAvatarUser(t, pool, "owner")
34 memberID := insertOrgAvatarUser(t, pool, "member")
35 outsiderID := insertOrgAvatarUser(t, pool, "outsider")
36 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
37 if _, err := pool.Exec(ctx, `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, 'member')`, orgID, memberID); err != nil {
38 t.Fatalf("insert org member: %v", err)
39 }
40 visibleTeamID := insertTeamForTest(t, pool, orgID, "engineering", "Engineering", "visible")
41 insertTeamForTest(t, pool, orgID, "security", "Security", "secret")
42 if _, err := pool.Exec(ctx, `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, visibleTeamID, memberID); err != nil {
43 t.Fatalf("insert team member: %v", err)
44 }
45
46 memberBody, memberStatus, _ := performTeamsListRequest(t, pool, middleware.CurrentUser{ID: memberID, Username: "member"}, "/acme/teams")
47 if memberStatus != http.StatusOK {
48 t.Fatalf("member status=%d body=%s", memberStatus, memberBody)
49 }
50 if !strings.Contains(memberBody, "TEAM=engineering:Engineering:1:0") {
51 t.Fatalf("expected visible team with counts, got %s", memberBody)
52 }
53 if strings.Contains(memberBody, "security") {
54 t.Fatalf("secret team leaked to non-team member: %s", memberBody)
55 }
56
57 outsiderBody, outsiderStatus, _ := performTeamsListRequest(t, pool, middleware.CurrentUser{ID: outsiderID, Username: "outsider"}, "/acme/teams")
58 if outsiderStatus != http.StatusNotFound {
59 t.Fatalf("outsider status=%d body=%s", outsiderStatus, outsiderBody)
60 }
61
62 _, anonymousStatus, anonymousLocation := performTeamsListRequest(t, pool, middleware.CurrentUser{}, "/acme/teams")
63 if anonymousStatus != http.StatusSeeOther {
64 t.Fatalf("anonymous status=%d", anonymousStatus)
65 }
66 if !strings.HasPrefix(anonymousLocation, "/login?next=") {
67 t.Fatalf("anonymous redirect=%q", anonymousLocation)
68 }
69 }
70
71 func TestTeamMemberAddRejectsNonOrgUsers(t *testing.T) {
72 t.Parallel()
73 ctx := context.Background()
74 pool := dbtest.NewTestDB(t)
75 ownerID := insertOrgAvatarUser(t, pool, "owner")
76 outsiderID := insertOrgAvatarUser(t, pool, "outsider")
77 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
78 teamID := insertTeamForTest(t, pool, orgID, "engineering", "Engineering", "visible")
79
80 form := url.Values{"user_id": {strconv.FormatInt(outsiderID, 10)}, "role": {"member"}}
81 body, status, _ := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/engineering/members", form)
82 if status != http.StatusBadRequest {
83 t.Fatalf("status=%d body=%s", status, body)
84 }
85 var count int
86 if err := pool.QueryRow(ctx, `SELECT count(*) FROM team_members WHERE team_id = $1`, teamID).Scan(&count); err != nil {
87 t.Fatalf("count team members: %v", err)
88 }
89 if count != 0 {
90 t.Fatalf("expected no team member insert, got %d", count)
91 }
92 }
93
94 func TestTeamCreateBlocksSecretTeamsWithoutEntitlement(t *testing.T) {
95 t.Parallel()
96 ctx := context.Background()
97 pool := dbtest.NewTestDB(t)
98 ownerID := insertOrgAvatarUser(t, pool, "owner")
99 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
100
101 form := url.Values{
102 "display_name": {"Security"},
103 "slug": {"security"},
104 "privacy": {"secret"},
105 }
106 body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams", form)
107 if status != http.StatusSeeOther {
108 t.Fatalf("status=%d body=%s", status, body)
109 }
110 if location != "/acme/teams?notice=secret-teams-upgrade" {
111 t.Fatalf("redirect=%q", location)
112 }
113 var count int
114 if err := pool.QueryRow(ctx, `SELECT count(*) FROM teams WHERE org_id = $1`, orgID).Scan(&count); err != nil {
115 t.Fatalf("count teams: %v", err)
116 }
117 if count != 0 {
118 t.Fatalf("expected no secret team insert, got %d", count)
119 }
120 }
121
122 func TestSecretTeamAddMemberRequiresEntitlementButRemoveAllowed(t *testing.T) {
123 t.Parallel()
124 ctx := context.Background()
125 pool := dbtest.NewTestDB(t)
126 ownerID := insertOrgAvatarUser(t, pool, "owner")
127 memberID := insertOrgAvatarUser(t, pool, "member")
128 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
129 if _, err := pool.Exec(ctx, `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, 'member')`, orgID, memberID); err != nil {
130 t.Fatalf("insert org member: %v", err)
131 }
132 teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret")
133
134 form := url.Values{"user_id": {strconv.FormatInt(memberID, 10)}, "role": {"member"}}
135 body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form)
136 if status != http.StatusSeeOther {
137 t.Fatalf("free add status=%d body=%s", status, body)
138 }
139 if location != "/acme/teams/security?notice=secret-teams-upgrade" {
140 t.Fatalf("free add redirect=%q", location)
141 }
142 assertTeamMemberCount(t, pool, teamID, 0)
143
144 if _, err := pool.Exec(ctx, `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, teamID, memberID); err != nil {
145 t.Fatalf("seed team member: %v", err)
146 }
147 assertTeamMemberCount(t, pool, teamID, 1)
148
149 remove := url.Values{
150 "user_id": {strconv.FormatInt(memberID, 10)},
151 "action": {"remove"},
152 }
153 body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", remove)
154 if status != http.StatusSeeOther {
155 t.Fatalf("remove status=%d body=%s", status, body)
156 }
157 assertTeamMemberCount(t, pool, teamID, 0)
158
159 activateTeamPlanForTest(t, pool, orgID)
160 body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/members", form)
161 if status != http.StatusSeeOther {
162 t.Fatalf("team add status=%d body=%s", status, body)
163 }
164 assertTeamMemberCount(t, pool, teamID, 1)
165 }
166
167 func TestSecretTeamRepoGrantRequiresEntitlementButRevokeAllowed(t *testing.T) {
168 t.Parallel()
169 ctx := context.Background()
170 pool := dbtest.NewTestDB(t)
171 ownerID := insertOrgAvatarUser(t, pool, "owner")
172 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
173 teamID := insertTeamForTest(t, pool, orgID, "security", "Security", "secret")
174 repoID := insertTeamRepoForTest(t, pool, orgID, "private-repo")
175
176 form := url.Values{"repo_id": {strconv.FormatInt(repoID, 10)}, "role": {"write"}}
177 body, status, location := performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form)
178 if status != http.StatusSeeOther {
179 t.Fatalf("free grant status=%d body=%s", status, body)
180 }
181 if location != "/acme/teams/security?notice=secret-teams-upgrade" {
182 t.Fatalf("free grant redirect=%q", location)
183 }
184 assertTeamRepoGrantCount(t, pool, teamID, 0)
185
186 if _, err := pool.Exec(ctx, `INSERT INTO team_repo_access (team_id, repo_id, role) VALUES ($1, $2, 'write')`, teamID, repoID); err != nil {
187 t.Fatalf("seed team repo access: %v", err)
188 }
189 assertTeamRepoGrantCount(t, pool, teamID, 1)
190
191 remove := url.Values{
192 "repo_id": {strconv.FormatInt(repoID, 10)},
193 "action": {"remove"},
194 }
195 body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", remove)
196 if status != http.StatusSeeOther {
197 t.Fatalf("revoke status=%d body=%s", status, body)
198 }
199 assertTeamRepoGrantCount(t, pool, teamID, 0)
200
201 activateTeamPlanForTest(t, pool, orgID)
202 body, status, _ = performTeamsRequest(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, http.MethodPost, "/acme/teams/security/repos", form)
203 if status != http.StatusSeeOther {
204 t.Fatalf("team grant status=%d body=%s", status, body)
205 }
206 assertTeamRepoGrantCount(t, pool, teamID, 1)
207 }
208
209 func performTeamsListRequest(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, target string) (string, int, string) {
210 return performTeamsRequest(t, pool, viewer, http.MethodGet, target, nil)
211 }
212
213 func performTeamsRequest(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, method, target string, form url.Values) (string, int, string) {
214 t.Helper()
215 rr, err := render.New(fstest.MapFS{
216 "_layout.html": {Data: []byte(`{{ define "layout" }}<html><body>{{ template "page" . }}</body></html>{{ end }}`)},
217 "orgs/teams_list.html": {Data: []byte(`{{ define "page" }}ACTIVE={{ .ActiveOrgNav }} TOTAL={{ .TeamTotalCount }}{{ range .Teams }} TEAM={{ .Slug }}:{{ .DisplayName }}:{{ .MemberCount }}:{{ .RepoCount }}{{ end }}{{ end }}`)},
218 "orgs/team_view.html": {Data: []byte(`{{ define "page" }}TEAM{{ end }}`)},
219 "orgs/people.html": {Data: []byte(`{{ define "page" }}PEOPLE{{ end }}`)},
220 "errors/400.html": {Data: []byte(`{{ define "page" }}400{{ end }}`)},
221 "errors/403.html": {Data: []byte(`{{ define "page" }}403{{ end }}`)},
222 "errors/404.html": {Data: []byte(`{{ define "page" }}404{{ end }}`)},
223 "errors/500.html": {Data: []byte(`{{ define "page" }}500{{ end }}`)},
224 }, render.Options{})
225 if err != nil {
226 t.Fatalf("render.New: %v", err)
227 }
228 h, err := orgsh.New(orgsh.Deps{
229 Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
230 Render: rr,
231 Pool: pool,
232 })
233 if err != nil {
234 t.Fatalf("orgsh.New: %v", err)
235 }
236 r := chi.NewRouter()
237 r.Use(func(next http.Handler) http.Handler {
238 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
239 next.ServeHTTP(w, r.WithContext(middleware.WithCurrentUserForTest(r.Context(), viewer)))
240 })
241 })
242 h.MountOrgRoutes(r)
243
244 var body io.Reader
245 if form != nil {
246 body = strings.NewReader(form.Encode())
247 }
248 req := httptest.NewRequest(method, target, body)
249 if form != nil {
250 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
251 }
252 rec := httptest.NewRecorder()
253 r.ServeHTTP(rec, req)
254 return rec.Body.String(), rec.Code, rec.Header().Get("Location")
255 }
256
257 func insertTeamForTest(t *testing.T, db orgsdb.DBTX, orgID int64, slug, displayName, privacy string) int64 {
258 t.Helper()
259 var id int64
260 if err := db.QueryRow(context.Background(),
261 `INSERT INTO teams (org_id, slug, display_name, privacy)
262 VALUES ($1, $2, $3, $4)
263 RETURNING id`,
264 orgID, slug, displayName, privacy,
265 ).Scan(&id); err != nil {
266 t.Fatalf("insert team: %v", err)
267 }
268 return id
269 }
270
271 func insertTeamRepoForTest(t *testing.T, db orgsdb.DBTX, orgID int64, name string) int64 {
272 t.Helper()
273 var id int64
274 if err := db.QueryRow(context.Background(),
275 `INSERT INTO repos (owner_org_id, name, visibility, default_branch)
276 VALUES ($1, $2, 'private', 'trunk')
277 RETURNING id`,
278 orgID, name,
279 ).Scan(&id); err != nil {
280 t.Fatalf("insert repo: %v", err)
281 }
282 return id
283 }
284
285 func activateTeamPlanForTest(t *testing.T, pool *pgxpool.Pool, orgID int64) {
286 t.Helper()
287 now := time.Now().UTC().Truncate(time.Second)
288 _, err := billing.ApplySubscriptionSnapshot(context.Background(), billing.Deps{Pool: pool}, billing.SubscriptionSnapshot{
289 OrgID: orgID,
290 Plan: billing.PlanTeam,
291 Status: billing.SubscriptionStatusActive,
292 StripeSubscriptionID: "sub_teams_" + strconv.FormatInt(orgID, 10),
293 StripeSubscriptionItemID: "si_teams_" + strconv.FormatInt(orgID, 10),
294 CurrentPeriodStart: now,
295 CurrentPeriodEnd: now.Add(30 * 24 * time.Hour),
296 LastWebhookEventID: "evt_teams_" + strconv.FormatInt(orgID, 10),
297 })
298 if err != nil {
299 t.Fatalf("activate team plan: %v", err)
300 }
301 }
302
303 func assertTeamMemberCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) {
304 t.Helper()
305 var count int
306 if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_members WHERE team_id = $1`, teamID).Scan(&count); err != nil {
307 t.Fatalf("count team members: %v", err)
308 }
309 if count != want {
310 t.Fatalf("team member count=%d, want %d", count, want)
311 }
312 }
313
314 func assertTeamRepoGrantCount(t *testing.T, db orgsdb.DBTX, teamID int64, want int) {
315 t.Helper()
316 var count int
317 if err := db.QueryRow(context.Background(), `SELECT count(*) FROM team_repo_access WHERE team_id = $1`, teamID).Scan(&count); err != nil {
318 t.Fatalf("count team repo grants: %v", err)
319 }
320 if count != want {
321 t.Fatalf("team repo grant count=%d, want %d", count, want)
322 }
323 }
324