| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package entitlements_test |
| 4 | |
import (
	"context"
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"

	"github.com/tenseleyFlow/shithub/internal/billing"
	"github.com/tenseleyFlow/shithub/internal/entitlements"
	"github.com/tenseleyFlow/shithub/internal/orgs"
	reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
	"github.com/tenseleyFlow/shithub/internal/testing/dbtest"
	usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
)
| 22 | |
// TestPrivateCollaborationUsageCountsEffectivePrivateAccess verifies that the
// org-wide usage counts only users who can actually reach a private repo. The
// expected set is the org owner, a direct collaborator on the private repo, and
// a member of a child team whose parent team holds a grant on that repo. A plain
// org member and a collaborator on a public repo must not be counted. It also
// checks that a free org reports the free limit and is not unlimited.
func TestPrivateCollaborationUsageCountsEffectivePrivateAccess(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	db := dbtest.NewTestDB(t)

	creator := createEntitlementUser(t, db, "owner")
	org, err := orgs.Create(ctx, orgs.Deps{Pool: db}, orgs.CreateParams{Slug: "acme", CreatedByUserID: creator.ID})
	if err != nil {
		t.Fatalf("create org: %v", err)
	}
	secretRepo := createEntitlementOrgRepo(t, db, org.ID, "secret", "private")
	pubRepo := createEntitlementOrgRepo(t, db, org.ID, "public", "public")

	// Org membership without any repo grant should not be counted.
	plain := createEntitlementUser(t, db, "plain")
	insertOrgMember(t, db, org.ID, plain.ID, "member")

	// Of the two direct collaborators, only the one on the private repo should count.
	directUser := createEntitlementUser(t, db, "direct")
	insertRepoCollaborator(t, db, secretRepo.ID, directUser.ID)
	pubOnlyUser := createEntitlementUser(t, db, "publiconly")
	insertRepoCollaborator(t, db, pubRepo.ID, pubOnlyUser.ID)

	// The grant sits on the parent team; the user belongs only to the child
	// team, so counting them requires following the team hierarchy.
	parent := insertEntitlementTeam(t, db, org.ID, "platform", 0)
	child := insertEntitlementTeam(t, db, org.ID, "runtime", parent)
	nested := createEntitlementUser(t, db, "childmember")
	insertTeamMember(t, db, child, nested.ID)
	insertTeamRepoGrant(t, db, parent, secretRepo.ID)

	usage, err := entitlements.PrivateCollaborationUsageForOrg(ctx, entitlements.Deps{Pool: db}, org.ID)
	if err != nil {
		t.Fatalf("PrivateCollaborationUsageForOrg: %v", err)
	}
	if got := usage.Count; got != 3 {
		t.Fatalf("private collaborator count=%d, want owner + direct + inherited team member", got)
	}
	if usage.Unlimited || usage.Limit != entitlements.FreePrivateCollaborationLimit {
		t.Fatalf("free usage limit = %+v", usage)
	}
}
| 60 | |
// TestPrivateCollaborationExpansionEnforcesFreeLimitAndTeamUnlimited exercises
// CheckPrivateCollaborationExpansion on an org with one private repo. The
// expected outcomes are:
//   - two candidates bring usage to 3 (owner + 2), which is allowed;
//   - three candidates bring usage to 4, which is blocked with
//     ErrPrivateCollaborationLimitExceeded and a message naming the limit;
//   - after an active Team subscription, the same three candidates are allowed
//     and the usage is reported as unlimited.
func TestPrivateCollaborationExpansionEnforcesFreeLimitAndTeamUnlimited(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	pool := dbtest.NewTestDB(t)
	owner := createEntitlementUser(t, pool, "owner")
	org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: owner.ID})
	if err != nil {
		t.Fatalf("create org: %v", err)
	}
	createEntitlementOrgRepo(t, pool, org.ID, "secret", "private")
	first := createEntitlementUser(t, pool, "first")
	second := createEntitlementUser(t, pool, "second")
	third := createEntitlementUser(t, pool, "third")
	freeDeps := entitlements.Deps{Pool: pool}

	// Owner plus two candidates lands exactly on the free limit.
	check, err := entitlements.CheckPrivateCollaborationExpansion(ctx, freeDeps, org.ID, entitlements.PrivateCollaborationExpansion{
		CandidateUserIDs: []int64{first.ID, second.ID},
	})
	if err != nil {
		t.Fatalf("allowed expansion: %v", err)
	}
	if !check.Allowed || check.WouldUse != 3 {
		t.Fatalf("two-user free expansion check = %+v, want allowed at limit", check)
	}

	// A third candidate exceeds the limit. Use errors.Is rather than == so the
	// assertion still holds if Err() wraps the sentinel with extra context.
	check, err = entitlements.CheckPrivateCollaborationExpansion(ctx, freeDeps, org.ID, entitlements.PrivateCollaborationExpansion{
		CandidateUserIDs: []int64{first.ID, second.ID, third.ID},
	})
	if err != nil {
		t.Fatalf("blocked expansion: %v", err)
	}
	if check.Allowed || check.WouldUse != 4 || !errors.Is(check.Err(), entitlements.ErrPrivateCollaborationLimitExceeded) {
		// %+v on the struct does not show the result of the Err() method,
		// so print the error explicitly.
		t.Fatalf("three-user free expansion check = %+v (err=%v), want blocked with ErrPrivateCollaborationLimitExceeded", check, check.Err())
	}
	if msg := check.Message(); !strings.Contains(msg, "up to 3 private collaborators") {
		t.Fatalf("message=%q, want concrete limit", msg)
	}

	// Activate a Team subscription at a fixed instant and give the checker the
	// same clock. NOTE(review): this presumably keeps subscription-period
	// evaluation deterministic; confirm against entitlements.Deps.Now usage.
	now := time.Now().UTC().Truncate(time.Second)
	if err := setSubscription(ctx, billing.Deps{Pool: pool}, org.ID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "private-collab"); err != nil {
		t.Fatalf("activate team: %v", err)
	}
	teamDeps := entitlements.Deps{Pool: pool, Now: func() time.Time { return now }}
	check, err = entitlements.CheckPrivateCollaborationExpansion(ctx, teamDeps, org.ID, entitlements.PrivateCollaborationExpansion{
		CandidateUserIDs: []int64{first.ID, second.ID, third.ID},
	})
	if err != nil {
		t.Fatalf("team expansion: %v", err)
	}
	if !check.Allowed || !check.Usage.Unlimited {
		t.Fatalf("team expansion check = %+v, want unlimited", check)
	}
}
| 112 | |
// TestPrivateRepoCreationCountsOwnersForFirstPrivateRepo verifies that the
// owners are counted even when the org has no private repository yet.
// CheckPrivateRepositoryCreation is called on an org with four owners (the
// creator plus three added owners). It must report WouldUse == 4 and refuse
// the creation.
func TestPrivateRepoCreationCountsOwnersForFirstPrivateRepo(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	pool := dbtest.NewTestDB(t)
	creator := createEntitlementUser(t, pool, "owner")
	org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: creator.ID})
	if err != nil {
		t.Fatalf("create org: %v", err)
	}

	extraOwners := []string{"owner2", "owner3", "owner4"}
	for i := range extraOwners {
		u := createEntitlementUser(t, pool, extraOwners[i])
		insertOrgMember(t, pool, org.ID, u.ID, "owner")
	}

	deps := entitlements.Deps{Pool: pool}
	check, err := entitlements.CheckPrivateRepositoryCreation(ctx, deps, org.ID)
	if err != nil {
		t.Fatalf("CheckPrivateRepositoryCreation: %v", err)
	}
	if blocked := !check.Allowed && check.WouldUse == 4; !blocked {
		t.Fatalf("first-private-repo check = %+v, want blocked by four owners", check)
	}
}
| 134 | |
// TestRepoPrivateVisibilityCountsRepoSpecificGrants checks
// CheckRepoPrivateVisibility for a public org repo about to become private.
// The users it would expose are the owner, a direct collaborator, and a member
// of a team granted on that repo. That makes 3, which is allowed. Adding one
// more direct collaborator raises WouldUse to 4 and the check is blocked.
func TestRepoPrivateVisibilityCountsRepoSpecificGrants(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	pool := dbtest.NewTestDB(t)
	creator := createEntitlementUser(t, pool, "owner")
	org, err := orgs.Create(ctx, orgs.Deps{Pool: pool}, orgs.CreateParams{Slug: "acme", CreatedByUserID: creator.ID})
	if err != nil {
		t.Fatalf("create org: %v", err)
	}
	repo := createEntitlementOrgRepo(t, pool, org.ID, "soon-private", "public")

	directUser := createEntitlementUser(t, pool, "direct")
	insertRepoCollaborator(t, pool, repo.ID, directUser.ID)
	securityTeam := insertEntitlementTeam(t, pool, org.ID, "security", 0)
	teamUser := createEntitlementUser(t, pool, "teamuser")
	insertTeamMember(t, pool, securityTeam, teamUser.ID)
	insertTeamRepoGrant(t, pool, securityTeam, repo.ID)

	deps := entitlements.Deps{Pool: pool}
	check, err := entitlements.CheckRepoPrivateVisibility(ctx, deps, org.ID, repo.ID)
	if err != nil {
		t.Fatalf("CheckRepoPrivateVisibility: %v", err)
	}
	if !(check.Allowed && check.WouldUse == 3) {
		t.Fatalf("public-to-private check = %+v, want owner + direct + team user allowed at limit", check)
	}

	// One more direct collaborator pushes the would-be usage past the limit.
	extra := createEntitlementUser(t, pool, "extra")
	insertRepoCollaborator(t, pool, repo.ID, extra.ID)
	check, err = entitlements.CheckRepoPrivateVisibility(ctx, deps, org.ID, repo.ID)
	if err != nil {
		t.Fatalf("CheckRepoPrivateVisibility after extra: %v", err)
	}
	if check.Allowed || check.WouldUse != 4 {
		t.Fatalf("public-to-private check with extra = %+v, want blocked", check)
	}
}
| 167 | |
// createEntitlementUser inserts a user whose username and display name are
// both username, with fixtureHash as the password hash. The test fails
// immediately if the insert returns an error.
func createEntitlementUser(t *testing.T, db usersdb.DBTX, username string) usersdb.User {
	t.Helper()
	params := usersdb.CreateUserParams{
		Username:     username,
		DisplayName:  username,
		PasswordHash: fixtureHash,
	}
	created, err := usersdb.New().CreateUser(context.Background(), db, params)
	if err != nil {
		t.Fatalf("create user %s: %v", username, err)
	}
	return created
}
| 180 | |
// createEntitlementOrgRepo creates a repository owned by orgID with the given
// name, the given visibility (callers pass "private" or "public"), and the
// default branch "trunk". The test fails immediately if the insert returns
// an error.
func createEntitlementOrgRepo(t *testing.T, db reposdb.DBTX, orgID int64, name, visibility string) reposdb.Repo {
	t.Helper()
	ownerOrg := pgtype.Int8{Int64: orgID, Valid: true}
	params := reposdb.CreateRepoParams{
		OwnerOrgID:    ownerOrg,
		Name:          name,
		DefaultBranch: "trunk",
		Visibility:    reposdb.RepoVisibility(visibility),
	}
	created, err := reposdb.New().CreateRepo(context.Background(), db, params)
	if err != nil {
		t.Fatalf("create repo %s: %v", name, err)
	}
	return created
}
| 194 | |
| 195 | func insertOrgMember(t *testing.T, db orgsdbtx, orgID, userID int64, role string) { |
| 196 | t.Helper() |
| 197 | if _, err := db.Exec(context.Background(), `INSERT INTO org_members (org_id, user_id, role) VALUES ($1, $2, $3)`, orgID, userID, role); err != nil { |
| 198 | t.Fatalf("insert org member: %v", err) |
| 199 | } |
| 200 | } |
| 201 | |
| 202 | func insertRepoCollaborator(t *testing.T, db orgsdbtx, repoID, userID int64) { |
| 203 | t.Helper() |
| 204 | if _, err := db.Exec(context.Background(), `INSERT INTO repo_collaborators (repo_id, user_id, role) VALUES ($1, $2, 'read')`, repoID, userID); err != nil { |
| 205 | t.Fatalf("insert repo collaborator: %v", err) |
| 206 | } |
| 207 | } |
| 208 | |
// insertEntitlementTeam creates a team in orgID whose slug and display name are
// both slug, and returns its generated id. A parentTeamID of 0 means no parent
// and inserts a top-level team. Any other value is stored as parent_team_id,
// which makes a child team. The test fails on any error.
func insertEntitlementTeam(t *testing.T, db orgsdbtx, orgID int64, slug string, parentTeamID int64) int64 {
	t.Helper()
	ctx := context.Background()
	var teamID int64

	if parentTeamID != 0 {
		row := db.QueryRow(ctx, `INSERT INTO teams (org_id, slug, display_name, parent_team_id) VALUES ($1, $2, $2, $3) RETURNING id`, orgID, slug, parentTeamID)
		if err := row.Scan(&teamID); err != nil {
			t.Fatalf("insert child team: %v", err)
		}
		return teamID
	}

	row := db.QueryRow(ctx, `INSERT INTO teams (org_id, slug, display_name) VALUES ($1, $2, $2) RETURNING id`, orgID, slug)
	if err := row.Scan(&teamID); err != nil {
		t.Fatalf("insert team: %v", err)
	}
	return teamID
}
| 223 | |
| 224 | func insertTeamMember(t *testing.T, db orgsdbtx, teamID, userID int64) { |
| 225 | t.Helper() |
| 226 | if _, err := db.Exec(context.Background(), `INSERT INTO team_members (team_id, user_id, role) VALUES ($1, $2, 'member')`, teamID, userID); err != nil { |
| 227 | t.Fatalf("insert team member: %v", err) |
| 228 | } |
| 229 | } |
| 230 | |
| 231 | func insertTeamRepoGrant(t *testing.T, db orgsdbtx, teamID, repoID int64) { |
| 232 | t.Helper() |
| 233 | if _, err := db.Exec(context.Background(), `INSERT INTO team_repo_access (team_id, repo_id, role) VALUES ($1, $2, 'read')`, teamID, repoID); err != nil { |
| 234 | t.Fatalf("insert team repo grant: %v", err) |
| 235 | } |
| 236 | } |
| 237 | |
| 238 | type orgsdbtx interface { |
| 239 | Exec(context.Context, string, ...any) (pgconn.CommandTag, error) |
| 240 | QueryRow(context.Context, string, ...any) pgx.Row |
| 241 | } |
| 242 |