Go · 3979 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package entitlements_test
4
5 import (
6 "context"
7 "strconv"
8 "testing"
9 "time"
10
11 "github.com/jackc/pgx/v5/pgtype"
12 "github.com/jackc/pgx/v5/pgxpool"
13
14 "github.com/tenseleyFlow/shithub/internal/billing"
15 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
16 "github.com/tenseleyFlow/shithub/internal/entitlements"
17 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
18 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
19 )
20
21 // TestProfilePinsCapConstantsAreDefinedForBothPlans locks the PRO01
22 // ratification: LimitProfilePinsFreeCap = 6, LimitProfilePinsProCap = 100.
23 // Both limits report Defined=true regardless of the principal's plan
24 // — the cap is a constant, the *applicable* cap is whichever matches
25 // CanUse(FeatureProfilePinsBeyondFree).
26 func TestProfilePinsCapConstantsAreDefinedForBothPlans(t *testing.T) {
27 t.Parallel()
28 ctx := context.Background()
29 pool, userID := setupEntitlementUser(t, "freepin")
30
31 set, err := entitlements.ForPrincipal(ctx, entitlements.Deps{Pool: pool}, billing.PrincipalForUser(userID))
32 if err != nil {
33 t.Fatalf("ForPrincipal: %v", err)
34 }
35 free, err := set.Limit(entitlements.LimitProfilePinsFreeCap)
36 if err != nil {
37 t.Fatalf("Limit free: %v", err)
38 }
39 if !free.Defined || free.Value != entitlements.FreeProfilePinsCap {
40 t.Errorf("LimitProfilePinsFreeCap = %+v, want defined value 6", free)
41 }
42 pro, err := set.Limit(entitlements.LimitProfilePinsProCap)
43 if err != nil {
44 t.Fatalf("Limit pro: %v", err)
45 }
46 if !pro.Defined || pro.Value != entitlements.ProProfilePinsCap {
47 t.Errorf("LimitProfilePinsProCap = %+v, want defined value 100", pro)
48 }
49 }
50
51 func TestProfilePinsBeyondFreeFreeUserCannotUseFeature(t *testing.T) {
52 t.Parallel()
53 ctx := context.Background()
54 pool, userID := setupEntitlementUser(t, "freecap")
55 set, err := entitlements.ForPrincipal(ctx, entitlements.Deps{Pool: pool}, billing.PrincipalForUser(userID))
56 if err != nil {
57 t.Fatalf("ForPrincipal: %v", err)
58 }
59 decision := set.CanUse(entitlements.FeatureProfilePinsBeyondFree)
60 if decision.Allowed {
61 t.Errorf("Free user should NOT be allowed FeatureProfilePinsBeyondFree, got %+v", decision)
62 }
63 }
64
65 func TestProfilePinsBeyondFreeProUserCanUseFeature(t *testing.T) {
66 t.Parallel()
67 ctx := context.Background()
68 pool, userID := setupEntitlementUser(t, "prouser")
69 now := time.Now().UTC()
70 if err := upgradeUserToPro(ctx, pool, userID, now); err != nil {
71 t.Fatalf("upgrade to pro: %v", err)
72 }
73
74 set, err := entitlements.ForPrincipal(ctx, entitlements.Deps{
75 Pool: pool,
76 Now: func() time.Time { return now },
77 }, billing.PrincipalForUser(userID))
78 if err != nil {
79 t.Fatalf("ForPrincipal: %v", err)
80 }
81 decision := set.CanUse(entitlements.FeatureProfilePinsBeyondFree)
82 if !decision.Allowed {
83 t.Errorf("Pro user should be allowed FeatureProfilePinsBeyondFree, got %+v", decision)
84 }
85 }
86
87 func setupEntitlementUser(t *testing.T, username string) (*pgxpool.Pool, int64) {
88 t.Helper()
89 pool := dbtest.NewTestDB(t)
90 user, err := usersdb.New().CreateUser(context.Background(), pool, usersdb.CreateUserParams{
91 Username: username, DisplayName: username, PasswordHash: fixtureHash,
92 })
93 if err != nil {
94 t.Fatalf("CreateUser: %v", err)
95 }
96 return pool, user.ID
97 }
98
99 func upgradeUserToPro(ctx context.Context, pool *pgxpool.Pool, userID int64, now time.Time) error {
100 suffix := strconv.FormatInt(userID, 10)
101 _, err := billingdb.New().ApplyUserSubscriptionSnapshot(ctx, pool, billingdb.ApplyUserSubscriptionSnapshotParams{
102 UserID: userID,
103 Plan: billingdb.UserPlanPro,
104 SubscriptionStatus: billingdb.BillingSubscriptionStatusActive,
105 StripeSubscriptionID: pgtype.Text{String: "sub_pro_" + suffix, Valid: true},
106 CurrentPeriodStart: pgtype.Timestamptz{Time: now.Add(-time.Hour), Valid: true},
107 CurrentPeriodEnd: pgtype.Timestamptz{Time: now.Add(30 * 24 * time.Hour), Valid: true},
108 LastWebhookEventID: "evt_pro_" + suffix,
109 })
110 return err
111 }
112