| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package entitlements_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "errors" |
| 8 | "io" |
| 9 | "log/slog" |
| 10 | "net/http" |
| 11 | "strings" |
| 12 | "testing" |
| 13 | "time" |
| 14 | |
| 15 | "github.com/jackc/pgx/v5/pgxpool" |
| 16 | |
| 17 | "github.com/tenseleyFlow/shithub/internal/billing" |
| 18 | "github.com/tenseleyFlow/shithub/internal/entitlements" |
| 19 | "github.com/tenseleyFlow/shithub/internal/orgs" |
| 20 | "github.com/tenseleyFlow/shithub/internal/testing/dbtest" |
| 21 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 22 | ) |
| 23 | |
// fixtureHash is the PasswordHash given to fixture users. It is shaped like an
// argon2id PHC string ($argon2id$v=19$m=...,t=...,p=...$<salt>$<key>) whose
// salt and key segments are placeholders made entirely of 'A' characters.
const (
	fixtureHash = "$argon2id$v=19$m=16384,t=1,p=1$" + // algorithm, version, cost parameters
		"AAAAAAAAAAAAAAAA$" + // salt segment
		"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" // key segment
)
| 27 | |
| 28 | func TestCheckOrgFeature(t *testing.T) { |
| 29 | t.Parallel() |
| 30 | tests := []struct { |
| 31 | name string |
| 32 | mutate func(context.Context, billing.Deps, int64, time.Time) error |
| 33 | now func(time.Time) time.Time |
| 34 | want bool |
| 35 | reason entitlements.Reason |
| 36 | }{ |
| 37 | { |
| 38 | name: "free org requires upgrade", |
| 39 | want: false, |
| 40 | reason: entitlements.ReasonUpgradeRequired, |
| 41 | }, |
| 42 | { |
| 43 | name: "team active allows feature", |
| 44 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 45 | return setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "active") |
| 46 | }, |
| 47 | want: true, |
| 48 | reason: entitlements.ReasonNone, |
| 49 | }, |
| 50 | { |
| 51 | name: "team trialing allows feature", |
| 52 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 53 | return setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusTrialing, "trialing") |
| 54 | }, |
| 55 | want: true, |
| 56 | reason: entitlements.ReasonNone, |
| 57 | }, |
| 58 | { |
| 59 | name: "team incomplete needs billing action", |
| 60 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 61 | return setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusIncomplete, "incomplete") |
| 62 | }, |
| 63 | want: false, |
| 64 | reason: entitlements.ReasonBillingActionNeeded, |
| 65 | }, |
| 66 | { |
| 67 | name: "team past due within grace still allows feature", |
| 68 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 69 | if err := setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "grace"); err != nil { |
| 70 | return err |
| 71 | } |
| 72 | _, err := billing.MarkPastDue(ctx, deps, orgID, now.Add(24*time.Hour), "evt_past_due") |
| 73 | return err |
| 74 | }, |
| 75 | now: func(now time.Time) time.Time { return now.Add(12 * time.Hour) }, |
| 76 | want: true, |
| 77 | reason: entitlements.ReasonNone, |
| 78 | }, |
| 79 | { |
| 80 | name: "team past due after grace needs billing action", |
| 81 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 82 | if err := setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "lapsed"); err != nil { |
| 83 | return err |
| 84 | } |
| 85 | _, err := billing.MarkPastDue(ctx, deps, orgID, now.Add(24*time.Hour), "evt_past_due") |
| 86 | return err |
| 87 | }, |
| 88 | now: func(now time.Time) time.Time { return now.Add(48 * time.Hour) }, |
| 89 | want: false, |
| 90 | reason: entitlements.ReasonBillingActionNeeded, |
| 91 | }, |
| 92 | { |
| 93 | name: "team locked without grace needs billing action", |
| 94 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 95 | return setSubscription(ctx, deps, orgID, now, billing.PlanTeam, billing.SubscriptionStatusPastDue, "locked") |
| 96 | }, |
| 97 | want: false, |
| 98 | reason: entitlements.ReasonBillingActionNeeded, |
| 99 | }, |
| 100 | { |
| 101 | name: "enterprise stub does not unlock team features", |
| 102 | mutate: func(ctx context.Context, deps billing.Deps, orgID int64, now time.Time) error { |
| 103 | return setSubscription(ctx, deps, orgID, now, billing.PlanEnterprise, billing.SubscriptionStatusActive, "enterprise") |
| 104 | }, |
| 105 | want: false, |
| 106 | reason: entitlements.ReasonEnterpriseContactSales, |
| 107 | }, |
| 108 | } |
| 109 | |
| 110 | for _, tt := range tests { |
| 111 | t.Run(tt.name, func(t *testing.T) { |
| 112 | t.Parallel() |
| 113 | ctx := context.Background() |
| 114 | pool, orgID := setupEntitlementOrg(t) |
| 115 | bdeps := billing.Deps{Pool: pool} |
| 116 | now := time.Now().UTC().Truncate(time.Second) |
| 117 | if tt.mutate != nil { |
| 118 | if err := tt.mutate(ctx, bdeps, orgID, now); err != nil { |
| 119 | t.Fatalf("mutate billing state: %v", err) |
| 120 | } |
| 121 | } |
| 122 | checkNow := now |
| 123 | if tt.now != nil { |
| 124 | checkNow = tt.now(now) |
| 125 | } |
| 126 | decision, err := entitlements.CheckOrgFeature(ctx, entitlements.Deps{ |
| 127 | Pool: pool, |
| 128 | Now: func() time.Time { return checkNow }, |
| 129 | }, orgID, entitlements.FeatureOrgActionsSecrets) |
| 130 | if err != nil { |
| 131 | t.Fatalf("CheckOrgFeature: %v", err) |
| 132 | } |
| 133 | if decision.Allowed != tt.want || decision.Reason != tt.reason { |
| 134 | t.Fatalf("decision = %+v, want allowed=%v reason=%s", decision, tt.want, tt.reason) |
| 135 | } |
| 136 | }) |
| 137 | } |
| 138 | } |
| 139 | |
// TestForOrgCanUseAndLimit puts an org on an active Team subscription, loads
// its entitlement set via entitlements.ForOrg, and asserts two things:
//   - every listed org feature is allowed;
//   - the private-collaboration limit is allowed, defined and unlimited (unit
//     "collaborators"), while the storage limit is allowed but not yet defined
//     (unit "bytes").
func TestForOrgCanUseAndLimit(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	pool, orgID := setupEntitlementOrg(t)
	now := time.Now().UTC().Truncate(time.Second)
	if err := setSubscription(ctx, billing.Deps{Pool: pool}, orgID, now, billing.PlanTeam, billing.SubscriptionStatusActive, "limits"); err != nil {
		t.Fatalf("set subscription: %v", err)
	}

	set, err := entitlements.ForOrg(ctx, entitlements.Deps{
		Pool: pool,
		Now:  func() time.Time { return now },
	}, orgID)
	if err != nil {
		t.Fatalf("ForOrg: %v", err)
	}

	// Report every denied feature instead of stopping at the first one, so a
	// regression in the Team feature list shows its full extent in one run.
	for _, feature := range []entitlements.Feature{
		entitlements.FeatureOrgSecretTeams,
		entitlements.FeatureOrgAdvancedBranchProtection,
		entitlements.FeatureOrgRequiredReviewers,
		entitlements.FeatureOrgActionsSecrets,
		entitlements.FeatureOrgActionsVariables,
		entitlements.FeatureOrgPrivateCollaboration,
		entitlements.FeatureOrgStorageQuota,
		entitlements.FeatureOrgActionsMinutesQuota,
	} {
		if decision := set.CanUse(feature); !decision.Allowed {
			t.Errorf("feature %s decision=%+v, want allowed", feature, decision)
		}
	}

	// Limit errors stay fatal because, by Go convention, the returned value is
	// not usable when err != nil. The field checks are independent of each
	// other, so they use Errorf.
	collab, err := set.Limit(entitlements.LimitOrgPrivateCollaboration)
	if err != nil {
		t.Fatalf("Limit private collaboration: %v", err)
	}
	if !collab.Allowed || !collab.Defined || !collab.Unlimited || collab.Unit != "collaborators" {
		t.Errorf("private collaboration limit = %+v, want allowed, defined, unlimited, unit collaborators", collab)
	}

	storage, err := set.Limit(entitlements.LimitOrgStorageQuota)
	if err != nil {
		t.Fatalf("Limit storage: %v", err)
	}
	if !storage.Allowed || storage.Defined || storage.Unit != "bytes" {
		t.Errorf("storage limit = %+v, want allowed but deferred concrete quota", storage)
	}
}
| 185 | |
// TestUnknownFeatureAndLimit asserts two things for names the entitlements
// package does not recognize:
//   - CheckOrgFeature wraps or returns ErrUnknownFeature for an unknown feature;
//   - the set returned by ForOrg yields ErrUnknownLimit from Limit for an
//     unknown limit.
func TestUnknownFeatureAndLimit(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	pool, orgID := setupEntitlementOrg(t)
	deps := entitlements.Deps{Pool: pool}

	if _, err := entitlements.CheckOrgFeature(ctx, deps, orgID, entitlements.Feature("org.mystery")); !errors.Is(err, entitlements.ErrUnknownFeature) {
		t.Fatalf("CheckOrgFeature unknown err=%v, want ErrUnknownFeature", err)
	}

	set, err := entitlements.ForOrg(ctx, deps, orgID)
	if err != nil {
		t.Fatalf("ForOrg: %v", err)
	}
	if _, err := set.Limit(entitlements.Limit("org.mystery_limit")); !errors.Is(err, entitlements.ErrUnknownLimit) {
		t.Fatalf("Limit unknown err=%v, want ErrUnknownLimit", err)
	}
}
| 203 | |
| 204 | func TestDecisionUpgradeBanner(t *testing.T) { |
| 205 | t.Parallel() |
| 206 | decision := entitlements.Decision{ |
| 207 | Feature: entitlements.FeatureOrgSecretTeams, |
| 208 | RequiredPlan: billing.PlanTeam, |
| 209 | Reason: entitlements.ReasonUpgradeRequired, |
| 210 | } |
| 211 | banner := decision.UpgradeBanner("Secret teams", "acme inc") |
| 212 | if banner.StatusCode != http.StatusPaymentRequired { |
| 213 | t.Fatalf("status=%d, want 402", banner.StatusCode) |
| 214 | } |
| 215 | if banner.ActionHref != "/organizations/acme%20inc/settings/billing" { |
| 216 | t.Fatalf("href=%q", banner.ActionHref) |
| 217 | } |
| 218 | if !strings.Contains(banner.Message, "require Team billing") { |
| 219 | t.Fatalf("message=%q", banner.Message) |
| 220 | } |
| 221 | } |
| 222 | |
// setSubscription applies a subscription snapshot for orgID through
// billing.ApplySubscriptionSnapshot. The current period runs from now to
// now+30 days. suffix is appended to the "sub_", "si_" and "evt_" prefixes to
// form the Stripe subscription, subscription item and last webhook event IDs.
// Only the error from ApplySubscriptionSnapshot is returned; its other result
// is discarded.
func setSubscription(ctx context.Context, deps billing.Deps, orgID int64, now time.Time, plan billing.Plan, status billing.SubscriptionStatus, suffix string) error {
	periodEnd := now.Add(30 * 24 * time.Hour)
	snapshot := billing.SubscriptionSnapshot{
		OrgID:                    orgID,
		Plan:                     plan,
		Status:                   status,
		StripeSubscriptionID:     "sub_" + suffix,
		StripeSubscriptionItemID: "si_" + suffix,
		CurrentPeriodStart:       now,
		CurrentPeriodEnd:         periodEnd,
		LastWebhookEventID:       "evt_" + suffix,
	}
	if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, snapshot); err != nil {
		return err
	}
	return nil
}
| 236 | |
// setupEntitlementOrg creates a fresh test database with dbtest.NewTestDB and
// inserts a user "owner" whose password hash is fixtureHash. It then creates
// the org "acme" ("Acme Inc") owned by that user, with org logging sent to
// io.Discard. It returns the pool and the new org's ID, and fails the test
// immediately if either insert errors.
func setupEntitlementOrg(t *testing.T) (*pgxpool.Pool, int64) {
	t.Helper()
	ctx := context.Background()
	pool := dbtest.NewTestDB(t)

	owner, err := usersdb.New().CreateUser(ctx, pool, usersdb.CreateUserParams{
		Username:     "owner",
		DisplayName:  "Owner",
		PasswordHash: fixtureHash,
	})
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	orgDeps := orgs.Deps{
		Pool:   pool,
		Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
	}
	created, err := orgs.Create(ctx, orgDeps, orgs.CreateParams{
		Slug:            "acme",
		DisplayName:     "Acme Inc",
		CreatedByUserID: owner.ID,
	})
	if err != nil {
		t.Fatalf("orgs.Create: %v", err)
	}
	return pool, created.ID
}
| 258 |