| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package entitlements_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "errors" |
| 8 | "strings" |
| 9 | "testing" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/jackc/pgx/v5" |
| 13 | "github.com/jackc/pgx/v5/pgconn" |
| 14 | "github.com/jackc/pgx/v5/pgtype" |
| 15 | |
| 16 | "github.com/tenseleyFlow/shithub/internal/billing" |
| 17 | "github.com/tenseleyFlow/shithub/internal/entitlements" |
| 18 | "github.com/tenseleyFlow/shithub/internal/orgs" |
| 19 | reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc" |
| 20 | "github.com/tenseleyFlow/shithub/internal/testing/dbtest" |
| 21 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 22 | ) |
| 23 | |
| 24 | func TestPrivateCollaborationUsageCountsEffectivePrivateAccess(t *testing.T) { |
| 25 | t.Parallel() |
| 26 | ctx := context.Background() |
| 27 | pool := dbtest.NewTestDB(t) |
| 28 | owner := createEntitlementUser(t, pool, "owner") |
| 29 | org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: owner.ID}) |
| 30 | if err != nil { |
| 31 | t.Fatalf("create org: %v", err) |
| 32 | } |
| 33 | privateRepo := createEntitlementOrgRepo(t, pool, org.ID, "secret", "private") |
| 34 | publicRepo := createEntitlementOrgRepo(t, pool, org.ID, "public", "public") |
| 35 | |
| 36 | plainMember := createEntitlementUser(t, pool, "plain") |
| 37 | insertOrgMember(t, pool, org.ID, plainMember.ID, "member") |
| 38 | |
| 39 | direct := createEntitlementUser(t, pool, "direct") |
| 40 | insertRepoCollaborator(t, pool, privateRepo.ID, direct.ID) |
| 41 | publicOnly := createEntitlementUser(t, pool, "publiconly") |
| 42 | insertRepoCollaborator(t, pool, publicRepo.ID, publicOnly.ID) |
| 43 | |
| 44 | parentTeamID := insertEntitlementTeam(t, pool, org.ID, "platform", 0) |
| 45 | childTeamID := insertEntitlementTeam(t, pool, org.ID, "runtime", parentTeamID) |
| 46 | childMember := createEntitlementUser(t, pool, "childmember") |
| 47 | insertTeamMember(t, pool, childTeamID, childMember.ID) |
| 48 | insertTeamRepoGrant(t, pool, parentTeamID, privateRepo.ID) |
| 49 | |
| 50 | usage, err := entitlements.PrivateCollaborationUsageForOrg(ctx, entitlements.Deps{Pool: pool}, org.ID) |
| 51 | if err != nil { |
| 52 | t.Fatalf("PrivateCollaborationUsageForOrg: %v", err) |
| 53 | } |
| 54 | if usage.Count != 3 { |
| 55 | t.Fatalf("private collaborator count=%d, want owner + direct + inherited team member", usage.Count) |
| 56 | } |
| 57 | if usage.Limit != entitlements.FreePrivateCollaborationLimit || usage.Unlimited { |
| 58 | t.Fatalf("free usage limit = %+v", usage) |
| 59 | } |
| 60 | } |
| 61 | |
| 62 | func TestPrivateCollaborationExpansionEnforcesFreeLimitAndTeamUnlimited(t *testing.T) { |
| 63 | t.Parallel() |
| 64 | ctx := context.Background() |
| 65 | pool := dbtest.NewTestDB(t) |
| 66 | owner := createEntitlementUser(t, pool, "owner") |
| 67 | org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: owner.ID}) |
| 68 | if err != nil { |
| 69 | t.Fatalf("create org: %v", err) |
| 70 | } |
| 71 | createEntitlementOrgRepo(t, pool, org.ID, "secret", "private") |
| 72 | first := createEntitlementUser(t, pool, "first") |
| 73 | second := createEntitlementUser(t, pool, "second") |
| 74 | third := createEntitlementUser(t, pool, "third") |
| 75 | |
| 76 | check, err := entitlements.CheckPrivateCollaborationExpansion(ctx, entitlements.Deps{Pool: pool}, org.ID, entitlements.PrivateCollaborationExpansion{ |
| 77 | CandidateUserIDs: []int64{first.ID, second.ID}, |
| 78 | }) |
| 79 | if err != nil { |
| 80 | t.Fatalf("allowed expansion: %v", err) |
| 81 | } |
| 82 | if !check.Allowed || check.WouldUse != 3 { |
| 83 | t.Fatalf("two-user free expansion check = %+v, want allowed at limit", check) |
| 84 | } |
| 85 | |
| 86 | check, err = entitlements.CheckPrivateCollaborationExpansion(ctx, entitlements.Deps{Pool: pool}, org.ID, entitlements.PrivateCollaborationExpansion{ |
| 87 | CandidateUserIDs: []int64{first.ID, second.ID, third.ID}, |
| 88 | }) |
| 89 | if err != nil { |
| 90 | t.Fatalf("blocked expansion: %v", err) |
| 91 | } |
| 92 | if check.Allowed || check.WouldUse != 4 || !errors.Is(check.Err(), entitlements.ErrPrivateCollaborationLimitExceeded) { |
| 93 | t.Fatalf("three-user free expansion check = %+v, want blocked", check) |
| 94 | } |
| 95 | if !strings.Contains(check.Message(), "up to 3 private collaborators") { |
| 96 | t.Fatalf("message=%q, want concrete limit", check.Message()) |
| 97 | } |
| 98 | |
| 99 | now := time.Now().UTC().Truncate(time.Second) |
| 100 | if err := setSubscription(ctx, billing.Deps{Pool: pool}, org.ID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "private-collab"); err != nil { |
| 101 | t.Fatalf("activate team: %v", err) |
| 102 | } |
| 103 | check, err = entitlements.CheckPrivateCollaborationExpansion(ctx, entitlements.Deps{Pool: pool, Now: func() time.Time { return now }}, org.ID, entitlements.PrivateCollaborationExpansion{ |
| 104 | CandidateUserIDs: []int64{first.ID, second.ID, third.ID}, |
| 105 | }) |
| 106 | if err != nil { |
| 107 | t.Fatalf("team expansion: %v", err) |
| 108 | } |
| 109 | if !check.Allowed || !check.Usage.Unlimited { |
| 110 | t.Fatalf("team expansion check = %+v, want unlimited", check) |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | func TestPrivateRepoCreationCountsOwnersForFirstPrivateRepo(t *testing.T) { |
| 115 | t.Parallel() |
| 116 | ctx := context.Background() |
| 117 | pool := dbtest.NewTestDB(t) |
| 118 | owner := createEntitlementUser(t, pool, "owner") |
| 119 | org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: owner.ID}) |
| 120 | if err != nil { |
| 121 | t.Fatalf("create org: %v", err) |
| 122 | } |
| 123 | for _, name := range []string{"owner2", "owner3", "owner4"} { |
| 124 | insertOrgMember(t, pool, org.ID, createEntitlementUser(t, pool, name).ID, "owner") |
| 125 | } |
| 126 | |
| 127 | check, err := entitlements.CheckPrivateRepositoryCreation(ctx, entitlements.Deps{Pool: pool}, org.ID) |
| 128 | if err != nil { |
| 129 | t.Fatalf("CheckPrivateRepositoryCreation: %v", err) |
| 130 | } |
| 131 | if check.Allowed || check.WouldUse != 4 { |
| 132 | t.Fatalf("first-private-repo check = %+v, want blocked by four owners", check) |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | func TestRepoPrivateVisibilityCountsRepoSpecificGrants(t *testing.T) { |
| 137 | t.Parallel() |
| 138 | ctx := context.Background() |
| 139 | pool := dbtest.NewTestDB(t) |
| 140 | owner := createEntitlementUser(t, pool, "owner") |
| 141 | org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: owner.ID}) |
| 142 | if err != nil { |
| 143 | t.Fatalf("create org: %v", err) |
| 144 | } |
| 145 | repo := createEntitlementOrgRepo(t, pool, org.ID, "soon-private", "public") |
| 146 | insertRepoCollaborator(t, pool, repo.ID, createEntitlementUser(t, pool, "direct").ID) |
| 147 | teamID := insertEntitlementTeam(t, pool, org.ID, "security", 0) |
| 148 | insertTeamMember(t, pool, teamID, createEntitlementUser(t, pool, "teamuser").ID) |
| 149 | insertTeamRepoGrant(t, pool, teamID, repo.ID) |
| 150 | |
| 151 | check, err := entitlements.CheckRepoPrivateVisibility(ctx, entitlements.Deps{Pool: pool}, org.ID, repo.ID) |
| 152 | if err != nil { |
| 153 | t.Fatalf("CheckRepoPrivateVisibility: %v", err) |
| 154 | } |
| 155 | if !check.Allowed || check.WouldUse != 3 { |
| 156 | t.Fatalf("public-to-private check = %+v, want owner + direct + team user allowed at limit", check) |
| 157 | } |
| 158 | |
| 159 | insertRepoCollaborator(t, pool, repo.ID, createEntitlementUser(t, pool, "extra").ID) |
| 160 | check, err = entitlements.CheckRepoPrivateVisibility(ctx, entitlements.Deps{Pool: pool}, org.ID, repo.ID) |
| 161 | if err != nil { |
| 162 | t.Fatalf("CheckRepoPrivateVisibility after extra: %v", err) |
| 163 | } |
| 164 | if check.Allowed || check.WouldUse != 4 { |
| 165 | t.Fatalf("public-to-private check with extra = %+v, want blocked", check) |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | func createEntitlementUser(t *testing.T, db usersdb.DBTX, username string) usersdb.User { |
| 170 | t.Helper() |
| 171 | user, err := usersdb.New().CreateUser(context.Background(), db, usersdb.CreateUserParams{ |
| 172 | Username: username, |
| 173 | DisplayName: username, |
| 174 | PasswordHash: fixtureHash, |
| 175 | }) |
| 176 | if err != nil { |
| 177 | t.Fatalf("create user %s: %v", username, err) |
| 178 | } |
| 179 | return user |
| 180 | } |
| 181 | |
| 182 | func createEntitlementOrgRepo(t *testing.T, db reposdb.DBTX, orgID int64, name, visibility string) reposdb.Repo { |
| 183 | t.Helper() |
| 184 | repo, err := reposdb.New().CreateRepo(context.Background(), db, reposdb.CreateRepoParams{ |
| 185 | OwnerOrgID: pgtype.Int8{Int64: orgID, Valid: true}, |
| 186 | Name: name, |
| 187 | DefaultBranch: "trunk", |
| 188 | Visibility: reposdb.RepoVisibility(visibility), |
| 189 | }) |
| 190 | if err != nil { |
| 191 | t.Fatalf("create repo %s: %v", name, err) |
| 192 | } |
| 193 | return repo |
| 194 | } |
| 195 | |
| 196 | func insertOrgMember(t *testing.T, db orgsdbtx, orgID, userID int64, role string) { |
| 197 | t.Helper() |
| 198 | if _, err := db.Exec(context.Background(), `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, $3)`, orgID, userID, role); err != nil { |
| 199 | t.Fatalf("insert org member: %v", err) |
| 200 | } |
| 201 | } |
| 202 | |
| 203 | func insertRepoCollaborator(t *testing.T, db orgsdbtx, repoID, userID int64) { |
| 204 | t.Helper() |
| 205 | if _, err := db.Exec(context.Background(), `INSERT INTO repo_collaborators (repo_id, user_id, role) VALUES ($1, $2, 'read')`, repoID, userID); err != nil { |
| 206 | t.Fatalf("insert repo collaborator: %v", err) |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | func insertEntitlementTeam(t *testing.T, db orgsdbtx, orgID int64, slug string, parentTeamID int64) int64 { |
| 211 | t.Helper() |
| 212 | var id int64 |
| 213 | if parentTeamID == 0 { |
| 214 | if err := db.QueryRow(context.Background(), `INSERT INTO teams (org_id, slug, display_name) VALUES ($1, $2, $2) RETURNING id`, orgID, slug).Scan(&id); err != nil { |
| 215 | t.Fatalf("insert team: %v", err) |
| 216 | } |
| 217 | return id |
| 218 | } |
| 219 | if err := db.QueryRow(context.Background(), `INSERT INTO teams (org_id, slug, display_name, parent_team_id) VALUES ($1, $2, $2, $3) RETURNING id`, orgID, slug, parentTeamID).Scan(&id); err != nil { |
| 220 | t.Fatalf("insert child team: %v", err) |
| 221 | } |
| 222 | return id |
| 223 | } |
| 224 | |
| 225 | func insertTeamMember(t *testing.T, db orgsdbtx, teamID, userID int64) { |
| 226 | t.Helper() |
| 227 | if _, err := db.Exec(context.Background(), `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, teamID, userID); err != nil { |
| 228 | t.Fatalf("insert team member: %v", err) |
| 229 | } |
| 230 | } |
| 231 | |
| 232 | func insertTeamRepoGrant(t *testing.T, db orgsdbtx, teamID, repoID int64) { |
| 233 | t.Helper() |
| 234 | if _, err := db.Exec(context.Background(), `INSERT INTO team_repo_access (team_id, repo_id, role) VALUES ($1, $2, 'read')`, teamID, repoID); err != nil { |
| 235 | t.Fatalf("insert team repo grant: %v", err) |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | type orgsdbtx interface { |
| 240 | Exec(context.Context, string, ...any) (pgconn.CommandTag, error) |
| 241 | QueryRow(context.Context, string, ...any) pgx.Row |
| 242 | } |
| 243 |