Go · 14041 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package entitlements
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "net/url"
10
11 "github.com/jackc/pgx/v5"
12 "github.com/jackc/pgx/v5/pgtype"
13 "github.com/jackc/pgx/v5/pgxpool"
14
15 "github.com/tenseleyFlow/shithub/internal/billing"
16 )
17
// ErrPrivateCollaborationLimitExceeded is the sentinel error that
// *PrivateCollaborationLimitError unwraps to, so callers can detect a
// limit denial with errors.Is.
var ErrPrivateCollaborationLimitExceeded = errors.New("entitlements: private collaboration limit exceeded")
19
// PrivateCollaborationExpansion describes a proposed change that may add
// private collaborators to an organization.
type PrivateCollaborationExpansion struct {
	// CandidateUserIDs are the users who would gain private access. Zero IDs
	// and users already counted as collaborators are skipped when the
	// expansion is evaluated.
	CandidateUserIDs []int64
	// AnonymousCandidates is a number of additional seats to count without a
	// user ID; negative values are treated as zero when evaluated.
	AnonymousCandidates int64
}
24
// PrivateCollaborationUsage is a snapshot of an organization's private
// collaborator count next to its entitlement limit.
type PrivateCollaborationUsage struct {
	// OrgID is the organization the snapshot was computed for.
	OrgID int64
	// Count is the number of distinct users currently counted as private
	// collaborators.
	Count int64
	// Limit is the value of the org's private collaboration limit.
	Limit int64
	// Unlimited mirrors the limit's Unlimited flag; checks skip enforcement
	// when it is set.
	Unlimited bool
	// RequiredPlan is copied from the resolved limit.
	RequiredPlan billing.Plan
	// Reason is copied from the resolved limit.
	Reason Reason
}
33
// PrivateCollaborationCheck is the outcome of evaluating a change against an
// organization's private collaboration limit.
type PrivateCollaborationCheck struct {
	// Allowed reports whether the change may proceed.
	Allowed bool
	// Usage is the usage snapshot the decision was based on.
	Usage PrivateCollaborationUsage
	// Added is the number of new collaborators the change would introduce.
	Added int64
	// WouldUse is the collaborator count after the change (Usage.Count plus
	// Added).
	WouldUse int64
	// RequiredPlan is copied from Usage.RequiredPlan.
	RequiredPlan billing.Plan
	// Reason is copied from Usage.Reason.
	Reason Reason
}
42
// PrivateCollaborationLimitError is the error form of a denied
// PrivateCollaborationCheck. It unwraps to
// ErrPrivateCollaborationLimitExceeded.
type PrivateCollaborationLimitError struct {
	// Check is the denied check that produced this error.
	Check PrivateCollaborationCheck
}
46
47 func (e *PrivateCollaborationLimitError) Error() string {
48 if e == nil {
49 return ErrPrivateCollaborationLimitExceeded.Error()
50 }
51 if msg := e.Check.Message(); msg != "" {
52 return msg
53 }
54 return ErrPrivateCollaborationLimitExceeded.Error()
55 }
56
// Unwrap always returns ErrPrivateCollaborationLimitExceeded so that
// errors.Is(err, ErrPrivateCollaborationLimitExceeded) matches this type.
func (e *PrivateCollaborationLimitError) Unwrap() error {
	return ErrPrivateCollaborationLimitExceeded
}
60
61 func PrivateCollaborationUsageForOrg(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationUsage, error) {
62 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
63 return usage, err
64 }
65
66 func CheckPrivateCollaborationExpansion(ctx context.Context, deps Deps, orgID int64, expansion PrivateCollaborationExpansion) (PrivateCollaborationCheck, error) {
67 usage, current, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
68 if err != nil {
69 return PrivateCollaborationCheck{}, err
70 }
71 check := PrivateCollaborationCheck{
72 Allowed: true,
73 Usage: usage,
74 WouldUse: usage.Count,
75 RequiredPlan: usage.RequiredPlan,
76 Reason: usage.Reason,
77 }
78 if usage.Unlimited {
79 return check, nil
80 }
81 added := expansion.AnonymousCandidates
82 if added < 0 {
83 added = 0
84 }
85 for _, userID := range expansion.CandidateUserIDs {
86 if userID == 0 {
87 continue
88 }
89 if _, ok := current[userID]; ok {
90 continue
91 }
92 current[userID] = struct{}{}
93 added++
94 }
95 check.Added = added
96 check.WouldUse = usage.Count + added
97 if added == 0 {
98 return check, nil
99 }
100 if check.WouldUse > usage.Limit {
101 check.Allowed = false
102 }
103 return check, nil
104 }
105
106 func CheckPrivateRepositoryCreation(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
107 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
108 if err != nil {
109 return PrivateCollaborationCheck{}, err
110 }
111 check := PrivateCollaborationCheck{
112 Allowed: true,
113 Usage: usage,
114 WouldUse: usage.Count,
115 RequiredPlan: usage.RequiredPlan,
116 Reason: usage.Reason,
117 }
118 if usage.Unlimited {
119 return check, nil
120 }
121 if usage.Count > usage.Limit {
122 check.Allowed = false
123 return check, nil
124 }
125 if usage.Count > 0 {
126 return check, nil
127 }
128 owners, err := orgOwnerIDs(ctx, deps.Pool, orgID)
129 if err != nil {
130 return PrivateCollaborationCheck{}, err
131 }
132 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: owners})
133 }
134
135 func CheckRepoPrivateVisibility(ctx context.Context, deps Deps, orgID, repoID int64) (PrivateCollaborationCheck, error) {
136 usage, _, err := privateCollaborationUsageWithIDs(ctx, deps, orgID)
137 if err != nil {
138 return PrivateCollaborationCheck{}, err
139 }
140 if !usage.Unlimited && usage.Count > usage.Limit {
141 return PrivateCollaborationCheck{
142 Allowed: false,
143 Usage: usage,
144 WouldUse: usage.Count,
145 RequiredPlan: usage.RequiredPlan,
146 Reason: usage.Reason,
147 }, nil
148 }
149 candidates, err := repoPrivateCollaboratorCandidateIDs(ctx, deps.Pool, orgID, repoID)
150 if err != nil {
151 return PrivateCollaborationCheck{}, err
152 }
153 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: candidates})
154 }
155
156 func CheckOrgOwnerPrivateCollaboration(ctx context.Context, deps Deps, orgID, userID int64) (PrivateCollaborationCheck, error) {
157 hasPrivate, err := orgHasPrivateRepos(ctx, deps.Pool, orgID)
158 if err != nil {
159 return PrivateCollaborationCheck{}, err
160 }
161 if !hasPrivate {
162 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
163 }
164 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
165 }
166
167 func CheckPrivateInvitationSlot(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
168 hasPrivate, err := orgHasPrivateRepos(ctx, deps.Pool, orgID)
169 if err != nil {
170 return PrivateCollaborationCheck{}, err
171 }
172 if !hasPrivate {
173 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
174 }
175 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{AnonymousCandidates: 1})
176 }
177
178 func CheckDirectPrivateCollaborator(ctx context.Context, deps Deps, repoID, userID int64) (PrivateCollaborationCheck, error) {
179 orgID, private, err := orgRepoPrivateState(ctx, deps.Pool, repoID)
180 if err != nil {
181 return PrivateCollaborationCheck{}, err
182 }
183 if orgID == 0 || !private {
184 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
185 }
186 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
187 }
188
189 func CheckTeamMemberPrivateCollaboration(ctx context.Context, deps Deps, teamID, userID int64) (PrivateCollaborationCheck, error) {
190 orgID, hasPrivateAccess, err := teamPrivateRepoAccessState(ctx, deps.Pool, teamID)
191 if err != nil {
192 return PrivateCollaborationCheck{}, err
193 }
194 if !hasPrivateAccess {
195 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
196 }
197 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: []int64{userID}})
198 }
199
200 func CheckTeamPrivateRepoGrant(ctx context.Context, deps Deps, teamID, repoID int64) (PrivateCollaborationCheck, error) {
201 orgID, private, err := orgRepoPrivateState(ctx, deps.Pool, repoID)
202 if err != nil {
203 return PrivateCollaborationCheck{}, err
204 }
205 if orgID == 0 || !private {
206 return allowedPrivateCollaborationCheck(ctx, deps, orgID)
207 }
208 candidates, err := teamGrantCandidateIDs(ctx, deps.Pool, teamID)
209 if err != nil {
210 return PrivateCollaborationCheck{}, err
211 }
212 return CheckPrivateCollaborationExpansion(ctx, deps, orgID, PrivateCollaborationExpansion{CandidateUserIDs: candidates})
213 }
214
215 func (c PrivateCollaborationCheck) Err() error {
216 if c.Allowed {
217 return nil
218 }
219 return &PrivateCollaborationLimitError{Check: c}
220 }
221
222 func (c PrivateCollaborationCheck) Message() string {
223 if c.Allowed || c.Usage.Unlimited {
224 return ""
225 }
226 return fmt.Sprintf("Free organizations can have up to %d private collaborators. This change would use %d. Upgrade to Team to add more.", c.Usage.Limit, c.WouldUse)
227 }
228
229 func (c PrivateCollaborationCheck) BillingPath(orgSlug string) string {
230 return "/organizations/" + url.PathEscape(orgSlug) + "/settings/billing"
231 }
232
233 func (c PrivateCollaborationCheck) UpgradeBanner(orgSlug string) UpgradeBanner {
234 return UpgradeBanner{
235 Message: c.Message(),
236 ActionText: "Manage billing and plans",
237 ActionHref: c.BillingPath(orgSlug),
238 StatusCode: c.HTTPStatus(),
239 }
240 }
241
242 func (c PrivateCollaborationCheck) HTTPStatus() int {
243 if c.Allowed {
244 return 200
245 }
246 return 402
247 }
248
249 func allowedPrivateCollaborationCheck(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationCheck, error) {
250 if orgID == 0 {
251 return PrivateCollaborationCheck{Allowed: true}, nil
252 }
253 usage, err := PrivateCollaborationUsageForOrg(ctx, deps, orgID)
254 if err != nil {
255 return PrivateCollaborationCheck{}, err
256 }
257 return PrivateCollaborationCheck{
258 Allowed: true,
259 Usage: usage,
260 WouldUse: usage.Count,
261 RequiredPlan: usage.RequiredPlan,
262 Reason: usage.Reason,
263 }, nil
264 }
265
266 func privateCollaborationUsageWithIDs(ctx context.Context, deps Deps, orgID int64) (PrivateCollaborationUsage, map[int64]struct{}, error) {
267 if deps.Pool == nil {
268 return PrivateCollaborationUsage{}, nil, ErrPoolRequired
269 }
270 if orgID == 0 {
271 return PrivateCollaborationUsage{}, nil, ErrOrgIDRequired
272 }
273 set, err := ForOrg(ctx, deps, orgID)
274 if err != nil {
275 return PrivateCollaborationUsage{}, nil, err
276 }
277 limit, err := set.Limit(LimitOrgPrivateCollaboration)
278 if err != nil {
279 return PrivateCollaborationUsage{}, nil, err
280 }
281 ids, err := currentPrivateCollaboratorIDs(ctx, deps.Pool, orgID)
282 if err != nil {
283 return PrivateCollaborationUsage{}, nil, err
284 }
285 return PrivateCollaborationUsage{
286 OrgID: orgID,
287 Count: int64(len(ids)),
288 Limit: limit.Value,
289 Unlimited: limit.Unlimited,
290 RequiredPlan: limit.RequiredPlan,
291 Reason: limit.Reason,
292 }, ids, nil
293 }
294
295 func currentPrivateCollaboratorIDs(ctx context.Context, pool *pgxpool.Pool, orgID int64) (map[int64]struct{}, error) {
296 return queryIDSet(ctx, pool, `
297 WITH private_repos AS (
298 SELECT id
299 FROM repos
300 WHERE owner_org_id = $1
301 AND visibility = 'private'
302 AND deleted_at IS NULL
303 ),
304 granting_teams AS (
305 SELECT DISTINCT tra.team_id
306 FROM team_repo_access tra
307 JOIN teams t ON t.id = tra.team_id AND t.org_id = $1
308 JOIN private_repos pr ON pr.id = tra.repo_id
309 )
310 SELECT DISTINCT user_id
311 FROM (
312 SELECT om.user_id
313 FROM org_members om
314 WHERE om.org_id = $1
315 AND om.role = 'owner'
316 AND EXISTS (SELECT 1 FROM private_repos)
317 UNION
318 SELECT rc.user_id
319 FROM repo_collaborators rc
320 JOIN private_repos pr ON pr.id = rc.repo_id
321 UNION
322 SELECT tm.user_id
323 FROM team_members tm
324 JOIN teams member_team ON member_team.id = tm.team_id AND member_team.org_id = $1
325 JOIN granting_teams gt ON gt.team_id = member_team.id OR gt.team_id = member_team.parent_team_id
326 ) effective
327 WHERE user_id IS NOT NULL`, orgID)
328 }
329
330 func repoPrivateCollaboratorCandidateIDs(ctx context.Context, pool *pgxpool.Pool, orgID, repoID int64) ([]int64, error) {
331 ids, err := queryIDSet(ctx, pool, `
332 WITH granting_teams AS (
333 SELECT DISTINCT tra.team_id
334 FROM team_repo_access tra
335 JOIN teams t ON t.id = tra.team_id AND t.org_id = $1
336 WHERE tra.repo_id = $2
337 )
338 SELECT DISTINCT user_id
339 FROM (
340 SELECT om.user_id
341 FROM org_members om
342 WHERE om.org_id = $1
343 AND om.role = 'owner'
344 UNION
345 SELECT rc.user_id
346 FROM repo_collaborators rc
347 WHERE rc.repo_id = $2
348 UNION
349 SELECT tm.user_id
350 FROM team_members tm
351 JOIN teams member_team ON member_team.id = tm.team_id AND member_team.org_id = $1
352 JOIN granting_teams gt ON gt.team_id = member_team.id OR gt.team_id = member_team.parent_team_id
353 ) effective
354 WHERE user_id IS NOT NULL`, orgID, repoID)
355 if err != nil {
356 return nil, err
357 }
358 return idSetToSlice(ids), nil
359 }
360
361 func teamGrantCandidateIDs(ctx context.Context, pool *pgxpool.Pool, teamID int64) ([]int64, error) {
362 ids, err := queryIDSet(ctx, pool, `
363 SELECT DISTINCT tm.user_id
364 FROM teams grant_team
365 JOIN teams member_team ON member_team.id = grant_team.id OR member_team.parent_team_id = grant_team.id
366 JOIN team_members tm ON tm.team_id = member_team.id
367 WHERE grant_team.id = $1`, teamID)
368 if err != nil {
369 return nil, err
370 }
371 return idSetToSlice(ids), nil
372 }
373
374 func orgOwnerIDs(ctx context.Context, pool *pgxpool.Pool, orgID int64) ([]int64, error) {
375 ids, err := queryIDSet(ctx, pool, `SELECT user_id FROM org_members WHERE org_id = $1 AND role = 'owner'`, orgID)
376 if err != nil {
377 return nil, err
378 }
379 return idSetToSlice(ids), nil
380 }
381
382 func orgHasPrivateRepos(ctx context.Context, pool *pgxpool.Pool, orgID int64) (bool, error) {
383 if pool == nil {
384 return false, ErrPoolRequired
385 }
386 var exists bool
387 err := pool.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM repos WHERE owner_org_id = $1 AND visibility = 'private' AND deleted_at IS NULL)`, orgID).Scan(&exists)
388 return exists, err
389 }
390
391 func orgRepoPrivateState(ctx context.Context, pool *pgxpool.Pool, repoID int64) (int64, bool, error) {
392 if pool == nil {
393 return 0, false, ErrPoolRequired
394 }
395 var ownerOrgID pgtype.Int8
396 var visibility string
397 err := pool.QueryRow(ctx, `SELECT owner_org_id, visibility::text FROM repos WHERE id = $1 AND deleted_at IS NULL`, repoID).Scan(&ownerOrgID, &visibility)
398 if errors.Is(err, pgx.ErrNoRows) {
399 return 0, false, nil
400 }
401 if err != nil {
402 return 0, false, err
403 }
404 if !ownerOrgID.Valid {
405 return 0, visibility == "private", nil
406 }
407 return ownerOrgID.Int64, visibility == "private", nil
408 }
409
410 func teamPrivateRepoAccessState(ctx context.Context, pool *pgxpool.Pool, teamID int64) (int64, bool, error) {
411 if pool == nil {
412 return 0, false, ErrPoolRequired
413 }
414 var orgID int64
415 var hasPrivateAccess bool
416 err := pool.QueryRow(ctx, `
417 SELECT t.org_id,
418 EXISTS(
419 SELECT 1
420 FROM team_repo_access tra
421 JOIN repos r ON r.id = tra.repo_id
422 WHERE r.owner_org_id = t.org_id
423 AND r.visibility = 'private'
424 AND r.deleted_at IS NULL
425 AND (tra.team_id = t.id OR tra.team_id = t.parent_team_id)
426 )
427 FROM teams t
428 WHERE t.id = $1`, teamID).Scan(&orgID, &hasPrivateAccess)
429 if errors.Is(err, pgx.ErrNoRows) {
430 return 0, false, nil
431 }
432 return orgID, hasPrivateAccess, err
433 }
434
435 func queryIDSet(ctx context.Context, pool *pgxpool.Pool, query string, args ...any) (map[int64]struct{}, error) {
436 if pool == nil {
437 return nil, ErrPoolRequired
438 }
439 rows, err := pool.Query(ctx, query, args...)
440 if err != nil {
441 return nil, err
442 }
443 defer rows.Close()
444 ids := make(map[int64]struct{})
445 for rows.Next() {
446 var id int64
447 if err := rows.Scan(&id); err != nil {
448 return nil, err
449 }
450 ids[id] = struct{}{}
451 }
452 return ids, rows.Err()
453 }
454
// idSetToSlice copies the keys of ids into a new slice. The result is never
// nil, and its order is unspecified because it follows map iteration order.
func idSetToSlice(ids map[int64]struct{}) []int64 {
	keys := make([]int64, len(ids))
	next := 0
	for id := range ids {
		keys[next] = id
		next++
	}
	return keys
}
461 }
462