Go · 13605 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package entitlements
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "net/url"
10
11 "github.com/jackc/pgx/v5"
12 "github.com/jackc/pgx/v5/pgtype"
13 "github.com/jackc/pgx/v5/pgxpool"
14
15 "github.com/tenseleyFlow/shithub/internal/billing"
16 )
17
// ErrPrivateCollaborationLimitExceeded is the error returned by
// PrivateCollaborationCheck.Err for any check whose Allowed field is false,
// i.e. a change refused because of the organization's private-collaboration
// limit.
var ErrPrivateCollaborationLimitExceeded = errors.New("entitlements: private collaboration limit exceeded")
19
// PrivateCollaborationExpansion describes the people a pending change would
// newly count as private collaborators; it is evaluated by
// CheckPrivateCollaborationExpansion.
type PrivateCollaborationExpansion struct {
	// CandidateUserIDs are known users who would gain access. Zero IDs,
	// duplicates, and users already counted toward the org's usage are not
	// counted again.
	CandidateUserIDs []int64
	// AnonymousCandidates is a number of not-yet-identified people, each
	// counted as a new collaborator (CheckPrivateInvitationSlot passes 1).
	// Negative values are treated as zero.
	AnonymousCandidates int64
}
24
// PrivateCollaborationUsage is an organization's current private-collaborator
// count together with the limit resolved from its entitlement set.
type PrivateCollaborationUsage struct {
	// OrgID is the organization the usage was computed for.
	OrgID int64
	// Count is the number of distinct users currently counted as private
	// collaborators of the org.
	Count int64
	// Limit is the maximum allowed Count; only meaningful when Unlimited is
	// false.
	Limit int64
	// Unlimited is true when the org's entitlement imposes no limit.
	Unlimited bool
	// RequiredPlan and Reason are copied from the resolved limit entitlement.
	RequiredPlan billing.Plan
	Reason       Reason
}
33
// PrivateCollaborationCheck is the outcome of evaluating a change against an
// organization's private-collaboration limit.
type PrivateCollaborationCheck struct {
	// Allowed reports whether the change may proceed.
	Allowed bool
	// Usage is the org's usage before the change (zero value when no
	// organization is involved).
	Usage PrivateCollaborationUsage
	// Added is the number of collaborators the change would newly add; it is
	// left at 0 by checks that return before counting candidates.
	Added int64
	// WouldUse is Usage.Count plus Added.
	WouldUse int64
	// RequiredPlan and Reason are copied from Usage.
	RequiredPlan billing.Plan
	Reason       Reason
}
42
43 func PrivateCollaborationUsageForOrg(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationUsage, error) {
44 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
45 return usage, err
46 }
47
48 func CheckPrivateCollaborationExpansion(ctx context.Context, deps Deps, orgID int64, expansion PrivateCollaborationExpansion) (PrivateCollaborationCheck, error) {
49 usage, current, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
50 if err != nil {
51 return PrivateCollaborationCheck{}, err
52 }
53 check := PrivateCollaborationCheck{
54 Allowed: true,
55 Usage: usage,
56 WouldUse: usage.Count,
57 RequiredPlan: usage.RequiredPlan,
58 Reason: usage.Reason,
59 }
60 if usage.Unlimited {
61 return check, nil
62 }
63 added := expansion.AnonymousCandidates
64 if added < 0 {
65 added = 0
66 }
67 for _, userID := range expansion.CandidateUserIDs {
68 if userID == 0 {
69 continue
70 }
71 if _, ok := current[userID]; ok {
72 continue
73 }
74 current[userID] = struct{}{}
75 added++
76 }
77 check.Added = added
78 check.WouldUse = usage.Count + added
79 if added == 0 {
80 return check, nil
81 }
82 if check.WouldUse > usage.Limit {
83 check.Allowed = false
84 }
85 return check, nil
86 }
87
88 func CheckPrivateRepositoryCreation(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
89 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
90 if err != nil {
91 return PrivateCollaborationCheck{}, err
92 }
93 check := PrivateCollaborationCheck{
94 Allowed: true,
95 Usage: usage,
96 WouldUse: usage.Count,
97 RequiredPlan: usage.RequiredPlan,
98 Reason: usage.Reason,
99 }
100 if usage.Unlimited {
101 return check, nil
102 }
103 if usage.Count > usage.Limit {
104 check.Allowed = false
105 return check, nil
106 }
107 if usage.Count > 0 {
108 return check, nil
109 }
110 owners, err := orgOwnerIDs(ctx, deps.Pool, orgID)
111 if err != nil {
112 return PrivateCollaborationCheck{}, err
113 }
114 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: owners})
115 }
116
117 func CheckRepoPrivateVisibility(ctx context.Context, deps Deps, orgID, repoID int64) (PrivateCollaborationCheck, error) {
118 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
119 if err != nil {
120 return PrivateCollaborationCheck{}, err
121 }
122 if !usage.Unlimited && usage.Count > usage.Limit {
123 return PrivateCollaborationCheck{
124 Allowed: false,
125 Usage: usage,
126 WouldUse: usage.Count,
127 RequiredPlan: usage.RequiredPlan,
128 Reason: usage.Reason,
129 }, nil
130 }
131 candidates, err := repoPrivateCollaboratorCandidateIDs(ctx, deps.Pool, orgID, repoID)
132 if err != nil {
133 return PrivateCollaborationCheck{}, err
134 }
135 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: candidates})
136 }
137
138 func CheckOrgOwnerPrivateCollaboration(ctx context.Context, deps Deps, orgID, userID int64) (PrivateCollaborationCheck, error) {
139 hasPrivate, err := orgHasPrivateRepos(ctx, deps.Pool, orgID)
140 if err != nil {
141 return PrivateCollaborationCheck{}, err
142 }
143 if !hasPrivate {
144 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
145 }
146 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
147 }
148
149 func CheckPrivateInvitationSlot(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
150 hasPrivate, err := orgHasPrivateRepos(ctx, deps.Pool, orgID)
151 if err != nil {
152 return PrivateCollaborationCheck{}, err
153 }
154 if !hasPrivate {
155 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
156 }
157 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{AnonymousCandidates: 1})
158 }
159
160 func CheckDirectPrivateCollaborator(ctx context.Context, deps Deps, repoID, userID int64) (PrivateCollaborationCheck, error) {
161 orgID, private, err := orgRepoPrivateState(ctx, deps.Pool, repoID)
162 if err != nil {
163 return PrivateCollaborationCheck{}, err
164 }
165 if orgID == 0 || !private {
166 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
167 }
168 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
169 }
170
171 func CheckTeamMemberPrivateCollaboration(ctx context.Context, deps Deps, teamID, userID int64) (PrivateCollaborationCheck, error) {
172 orgID, hasPrivateAccess, err := teamPrivateRepoAccessState(ctx, deps.Pool, teamID)
173 if err != nil {
174 return PrivateCollaborationCheck{}, err
175 }
176 if !hasPrivateAccess {
177 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
178 }
179 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
180 }
181
182 func CheckTeamPrivateRepoGrant(ctx context.Context, deps Deps, teamID, repoID int64) (PrivateCollaborationCheck, error) {
183 orgID, private, err := orgRepoPrivateState(ctx, deps.Pool, repoID)
184 if err != nil {
185 return PrivateCollaborationCheck{}, err
186 }
187 if orgID == 0 || !private {
188 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
189 }
190 candidates, err := teamGrantCandidateIDs(ctx, deps.Pool, teamID)
191 if err != nil {
192 return PrivateCollaborationCheck{}, err
193 }
194 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: candidates})
195 }
196
197 func (c PrivateCollaborationCheck) Err() error {
198 if c.Allowed {
199 return nil
200 }
201 return ErrPrivateCollaborationLimitExceeded
202 }
203
204 func (c PrivateCollaborationCheck) Message() string {
205 if c.Allowed || c.Usage.Unlimited {
206 return ""
207 }
208 return fmt.Sprintf("Free organizations can have up to %d private collaborators. This change would use %d. Upgrade to Team to add more.", c.Usage.Limit, c.WouldUse)
209 }
210
211 func (c PrivateCollaborationCheck) BillingPath(orgSlug string) string {
212 return "/organizations/" + url.PathEscape(orgSlug) + "/settings/billing"
213 }
214
215 func (c PrivateCollaborationCheck) UpgradeBanner(orgSlug string) UpgradeBanner {
216 return UpgradeBanner{
217 Message: c.Message(),
218 ActionText: "Manage billing and plans",
219 ActionHref: c.BillingPath(orgSlug),
220 StatusCode: c.HTTPStatus(),
221 }
222 }
223
224 func (c PrivateCollaborationCheck) HTTPStatus() int {
225 if c.Allowed {
226 return 200
227 }
228 return 402
229 }
230
231 func allowedPrivateCollaborationCheck(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
232 if orgID == 0 {
233 return PrivateCollaborationCheck{Allowed: true}, nil
234 }
235 usage, err := PrivateCollaborationUsageForOrg(ctx, deps, orgID)
236 if err != nil {
237 return PrivateCollaborationCheck{}, err
238 }
239 return PrivateCollaborationCheck{
240 Allowed: true,
241 Usage: usage,
242 WouldUse: usage.Count,
243 RequiredPlan: usage.RequiredPlan,
244 Reason: usage.Reason,
245 }, nil
246 }
247
248 func privateCollaborationUsageWithIDs(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationUsage, map[int64]struct{}, error) {
249 if deps.Pool == nil {
250 return PrivateCollaborationUsage{}, nil, ErrPoolRequired
251 }
252 if orgID == 0 {
253 return PrivateCollaborationUsage{}, nil, ErrOrgIDRequired
254 }
255 set, err := ForOrg(ctx, deps, orgID)
256 if err != nil {
257 return PrivateCollaborationUsage{}, nil, err
258 }
259 limit, err := set.Limit(LimitOrgPrivateCollaboration)
260 if err != nil {
261 return PrivateCollaborationUsage{}, nil, err
262 }
263 ids, err := currentPrivateCollaboratorIDs(ctx, deps.Pool, orgID)
264 if err != nil {
265 return PrivateCollaborationUsage{}, nil, err
266 }
267 return PrivateCollaborationUsage{
268 OrgID: orgID,
269 Count: int64(len(ids)),
270 Limit: limit.Value,
271 Unlimited: limit.Unlimited,
272 RequiredPlan: limit.RequiredPlan,
273 Reason: limit.Reason,
274 }, ids, nil
275 }
276
277 func currentPrivateCollaboratorIDs(ctx context.Context, pool *pgxpool.Pool, orgID int64) (map[int64]struct{}, error) {
278 return queryIDSet(ctx, pool, `
279 WITH private_repos AS (
280 SELECT id
281 FROM repos
282 WHERE owner_org_id = $1
283 AND visibility = 'private'
284 AND deleted_at IS NULL
285 ),
286 granting_teams AS (
287 SELECT DISTINCT tra.team_id
288 FROM team_repo_access tra
289 JOIN teams t ON t.id = tra.team_id AND t.org_id = $1
290 JOIN private_repos pr ON pr.id = tra.repo_id
291 )
292 SELECT DISTINCT user_id
293 FROM (
294 SELECT om.user_id
295 FROM org_members om
296 WHERE om.org_id = $1
297 AND om.role = 'owner'
298 AND EXISTS (SELECT 1 FROM private_repos)
299 UNION
300 SELECT rc.user_id
301 FROM repo_collaborators rc
302 JOIN private_repos pr ON pr.id = rc.repo_id
303 UNION
304 SELECT tm.user_id
305 FROM team_members tm
306 JOIN teams member_team ON member_team.id = tm.team_id AND member_team.org_id = $1
307 JOIN granting_teams gt ON gt.team_id = member_team.id OR gt.team_id = member_team.parent_team_id
308 ) effective
309 WHERE user_id IS NOT NULL`, orgID)
310 }
311
312 func repoPrivateCollaboratorCandidateIDs(ctx context.Context, pool *pgxpool.Pool, orgID, repoID int64) ([]int64, error) {
313 ids, err := queryIDSet(ctx, pool, `
314 WITH granting_teams AS (
315 SELECT DISTINCT tra.team_id
316 FROM team_repo_access tra
317 JOIN teams t ON t.id = tra.team_id AND t.org_id = $1
318 WHERE tra.repo_id = $2
319 )
320 SELECT DISTINCT user_id
321 FROM (
322 SELECT om.user_id
323 FROM org_members om
324 WHERE om.org_id = $1
325 AND om.role = 'owner'
326 UNION
327 SELECT rc.user_id
328 FROM repo_collaborators rc
329 WHERE rc.repo_id = $2
330 UNION
331 SELECT tm.user_id
332 FROM team_members tm
333 JOIN teams member_team ON member_team.id = tm.team_id AND member_team.org_id = $1
334 JOIN granting_teams gt ON gt.team_id = member_team.id OR gt.team_id = member_team.parent_team_id
335 ) effective
336 WHERE user_id IS NOT NULL`, orgID, repoID)
337 if err != nil {
338 return nil, err
339 }
340 return idSetToSlice(ids), nil
341 }
342
343 func teamGrantCandidateIDs(ctx context.Context, pool *pgxpool.Pool, teamID int64) ([]int64, error) {
344 ids, err := queryIDSet(ctx, pool, `
345 SELECT DISTINCT tm.user_id
346 FROM teams grant_team
347 JOIN teams member_team ON member_team.id = grant_team.id OR member_team.parent_team_id = grant_team.id
348 JOIN team_members tm ON tm.team_id = member_team.id
349 WHERE grant_team.id = $1`, teamID)
350 if err != nil {
351 return nil, err
352 }
353 return idSetToSlice(ids), nil
354 }
355
356 func orgOwnerIDs(ctx context.Context, pool *pgxpool.Pool, orgID int64) ([]int64, error) {
357 ids, err := queryIDSet(ctx, pool, `SELECT user_id FROM org_members WHERE org_id = $1 AND role = 'owner'`, orgID)
358 if err != nil {
359 return nil, err
360 }
361 return idSetToSlice(ids), nil
362 }
363
364 func orgHasPrivateRepos(ctx context.Context, pool *pgxpool.Pool, orgID int64) (bool, error) {
365 if pool == nil {
366 return false, ErrPoolRequired
367 }
368 var exists bool
369 err := pool.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM repos WHERE owner_org_id = $1 AND visibility = 'private' AND deleted_at IS NULL)`, orgID).Scan(&exists)
370 return exists, err
371 }
372
373 func orgRepoPrivateState(ctx context.Context, pool *pgxpool.Pool, repoID int64) (int64, bool, error) {
374 if pool == nil {
375 return 0, false, ErrPoolRequired
376 }
377 var ownerOrgID pgtype.Int8
378 var visibility string
379 err := pool.QueryRow(ctx, `SELECT owner_org_id, visibility::text FROM repos WHERE id = $1 AND deleted_at IS NULL`, repoID).Scan(&ownerOrgID, &visibility)
380 if errors.Is(err, pgx.ErrNoRows) {
381 return 0, false, nil
382 }
383 if err != nil {
384 return 0, false, err
385 }
386 if !ownerOrgID.Valid {
387 return 0, visibility == "private", nil
388 }
389 return ownerOrgID.Int64, visibility == "private", nil
390 }
391
392 func teamPrivateRepoAccessState(ctx context.Context, pool *pgxpool.Pool, teamID int64) (int64, bool, error) {
393 if pool == nil {
394 return 0, false, ErrPoolRequired
395 }
396 var orgID int64
397 var hasPrivateAccess bool
398 err := pool.QueryRow(ctx, `
399 SELECT t.org_id,
400 EXISTS(
401 SELECT 1
402 FROM team_repo_access tra
403 JOIN repos r ON r.id = tra.repo_id
404 WHERE r.owner_org_id = t.org_id
405 AND r.visibility = 'private'
406 AND r.deleted_at IS NULL
407 AND (tra.team_id = t.id OR tra.team_id = t.parent_team_id)
408 )
409 FROM teams t
410 WHERE t.id = $1`, teamID).Scan(&orgID, &hasPrivateAccess)
411 if errors.Is(err, pgx.ErrNoRows) {
412 return 0, false, nil
413 }
414 return orgID, hasPrivateAccess, err
415 }
416
417 func queryIDSet(ctx context.Context, pool *pgxpool.Pool, query string, args ...any) (map[int64]struct{}, error) {
418 if pool == nil {
419 return nil, ErrPoolRequired
420 }
421 rows, err := pool.Query(ctx, query, args...)
422 if err != nil {
423 return nil, err
424 }
425 defer rows.Close()
426 ids := make(map[int64]struct{})
427 for rows.Next() {
428 var id int64
429 if err := rows.Scan(&id); err != nil {
430 return nil, err
431 }
432 ids[id] = struct{}{}
433 }
434 return ids, rows.Err()
435 }
436
// idSetToSlice flattens an ID set into a slice in unspecified (map iteration)
// order. The result is never nil; an empty or nil set yields an empty slice.
func idSetToSlice(ids map[int64]struct{}) []int64 {
	flat := make([]int64, len(ids))
	i := 0
	for id := range ids {
		flat[i] = id
		i++
	}
	return flat
}
444