Go · 6370 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package billing_test
4
5 import (
6 "context"
7 "testing"
8 "time"
9
10 "github.com/jackc/pgx/v5/pgtype"
11 "github.com/jackc/pgx/v5/pgxpool"
12
13 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
14 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
15 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
16 )
17
18 // setupUser mirrors `setup` but returns just the user — PRO03 user
19 // billing tests don't need an org. The schema's AFTER-INSERT trigger
20 // from 0073 seeds a `free` row in `user_billing_states` automatically.
21 func setupUser(t *testing.T, username string) (int64, *pgxpool.Pool) {
22 t.Helper()
23 p := dbtest.NewTestDB(t)
24 u, err := usersdb.New().CreateUser(context.Background(), p, usersdb.CreateUserParams{
25 Username: username, DisplayName: username, PasswordHash: fixtureHash,
26 })
27 if err != nil {
28 t.Fatalf("create user: %v", err)
29 }
30 return u.ID, p
31 }
32
33 func TestUserBillingState_SeedTriggerFires(t *testing.T) {
34 uID, pool := setupUser(t, "alice")
35 state, err := billingdb.New().GetUserBillingState(context.Background(), pool, uID)
36 if err != nil {
37 t.Fatalf("GetUserBillingState: %v", err)
38 }
39 if state.Plan != billingdb.UserPlanFree {
40 t.Fatalf("seed plan: got %s, want free", state.Plan)
41 }
42 if state.SubscriptionStatus != billingdb.BillingSubscriptionStatusNone {
43 t.Fatalf("seed status: got %s, want none", state.SubscriptionStatus)
44 }
45 if state.UserID != uID {
46 t.Fatalf("seed user_id: got %d, want %d", state.UserID, uID)
47 }
48 }
49
50 func TestUserBillingState_TransitionsFreeToProToPastDueToActiveToCanceled(t *testing.T) {
51 uID, pool := setupUser(t, "alice")
52 q := billingdb.New()
53 ctx := context.Background()
54 start := time.Now().UTC().Truncate(time.Second)
55
56 // free → pro (ApplyUserSubscriptionSnapshot active).
57 applied, err := q.ApplyUserSubscriptionSnapshot(ctx, pool, billingdb.ApplyUserSubscriptionSnapshotParams{
58 UserID: uID,
59 Plan: billingdb.UserPlanPro,
60 SubscriptionStatus: billingdb.BillingSubscriptionStatusActive,
61 StripeSubscriptionID: pgtype.Text{String: "sub_user_test", Valid: true},
62 StripeSubscriptionItemID: pgtype.Text{String: "si_user_test", Valid: true},
63 CurrentPeriodStart: pgtype.Timestamptz{Time: start, Valid: true},
64 CurrentPeriodEnd: pgtype.Timestamptz{Time: start.Add(30 * 24 * time.Hour), Valid: true},
65 CancelAtPeriodEnd: false,
66 LastWebhookEventID: "evt_user_active",
67 })
68 if err != nil {
69 t.Fatalf("Apply free→pro: %v", err)
70 }
71 if applied.Plan != billingdb.UserPlanPro {
72 t.Errorf("post-apply plan: got %s, want pro", applied.Plan)
73 }
74 if applied.SubscriptionStatus != billingdb.BillingSubscriptionStatusActive {
75 t.Errorf("post-apply status: got %s, want active", applied.SubscriptionStatus)
76 }
77 // users.plan must mirror.
78 user, _ := usersdb.New().GetUserByID(ctx, pool, uID)
79 if user.Plan != usersdb.UserPlanPro {
80 t.Errorf("users.plan after pro apply: got %s, want pro", user.Plan)
81 }
82
83 // pro → past_due (MarkUserPastDue returns UserBillingState — no CTE).
84 pastDue, err := q.MarkUserPastDue(ctx, pool, billingdb.MarkUserPastDueParams{
85 UserID: uID,
86 GraceUntil: pgtype.Timestamptz{Time: start.Add(7 * 24 * time.Hour), Valid: true},
87 LastWebhookEventID: "evt_user_past_due",
88 })
89 if err != nil {
90 t.Fatalf("MarkUserPastDue: %v", err)
91 }
92 if pastDue.SubscriptionStatus != billingdb.BillingSubscriptionStatusPastDue {
93 t.Errorf("past_due status: %s", pastDue.SubscriptionStatus)
94 }
95 if !pastDue.LockedAt.Valid {
96 t.Errorf("past_due locked_at not set")
97 }
98 if pastDue.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue {
99 t.Errorf("past_due lock_reason: %s", pastDue.LockReason.BillingLockReason)
100 }
101
102 // past_due → active (MarkUserPaymentSucceeded).
103 recovered, err := q.MarkUserPaymentSucceeded(ctx, pool, billingdb.MarkUserPaymentSucceededParams{
104 UserID: uID,
105 LastWebhookEventID: "evt_user_recovered",
106 })
107 if err != nil {
108 t.Fatalf("MarkUserPaymentSucceeded: %v", err)
109 }
110 if recovered.SubscriptionStatus != billingdb.BillingSubscriptionStatusActive {
111 t.Errorf("recovered status: %s", recovered.SubscriptionStatus)
112 }
113 if recovered.Plan != billingdb.UserPlanPro {
114 t.Errorf("recovered plan: %s", recovered.Plan)
115 }
116 if recovered.LockedAt.Valid {
117 t.Errorf("recovered locked_at: should be cleared")
118 }
119
120 // active → canceled (MarkUserCanceled).
121 canceled, err := q.MarkUserCanceled(ctx, pool, billingdb.MarkUserCanceledParams{
122 UserID: uID,
123 LastWebhookEventID: "evt_user_canceled",
124 })
125 if err != nil {
126 t.Fatalf("MarkUserCanceled: %v", err)
127 }
128 if canceled.SubscriptionStatus != billingdb.BillingSubscriptionStatusCanceled {
129 t.Errorf("canceled status: %s", canceled.SubscriptionStatus)
130 }
131 if canceled.Plan != billingdb.UserPlanFree {
132 t.Errorf("canceled plan: got %s, want free", canceled.Plan)
133 }
134 user, _ = usersdb.New().GetUserByID(ctx, pool, uID)
135 if user.Plan != usersdb.UserPlanFree {
136 t.Errorf("users.plan after cancel: got %s, want free", user.Plan)
137 }
138
139 // canceled → unlocked + status none (ClearUserBillingLock).
140 cleared, err := q.ClearUserBillingLock(ctx, pool, uID)
141 if err != nil {
142 t.Fatalf("ClearUserBillingLock: %v", err)
143 }
144 if cleared.SubscriptionStatus != billingdb.BillingSubscriptionStatusNone {
145 t.Errorf("cleared status: got %s, want none", cleared.SubscriptionStatus)
146 }
147 if cleared.LockedAt.Valid {
148 t.Errorf("cleared locked_at: should be NULL")
149 }
150 }
151
152 func TestUserBillingState_SetStripeCustomer(t *testing.T) {
153 uID, pool := setupUser(t, "alice")
154 state, err := billingdb.New().SetUserStripeCustomer(context.Background(), pool, billingdb.SetUserStripeCustomerParams{
155 UserID: uID,
156 StripeCustomerID: pgtype.Text{String: "cus_user_test", Valid: true},
157 })
158 if err != nil {
159 t.Fatalf("SetUserStripeCustomer: %v", err)
160 }
161 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_user_test" {
162 t.Fatalf("stripe customer: %+v", state.StripeCustomerID)
163 }
164
165 // Lookup by customer id.
166 got, err := billingdb.New().GetUserBillingStateByStripeCustomer(context.Background(), pool, pgtype.Text{String: "cus_user_test", Valid: true})
167 if err != nil {
168 t.Fatalf("GetUserBillingStateByStripeCustomer: %v", err)
169 }
170 if got.UserID != uID {
171 t.Errorf("lookup user_id: got %d, want %d", got.UserID, uID)
172 }
173 }
174