Go · 7012 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package billing_test
4
5 import (
6 "context"
7 "testing"
8 "time"
9
10 "github.com/jackc/pgx/v5/pgtype"
11
12 "github.com/tenseleyFlow/shithub/internal/billing"
13 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
14 )
15
16 // PRO08 D1 — lock-column preservation across snapshot applies.
17 //
18 // Stripe delivers `customer.subscription.updated[status=past_due]` and
19 // `invoice.payment_failed` in no guaranteed order. Pre-PRO08, the
20 // snapshot CTE unconditionally NULLed locked_at/grace_until/lock_reason,
21 // so a subscription.updated arriving AFTER invoice.payment_failed
22 // wiped the grace lock — the user would see the "Pro features are
23 // read-only" banner immediately instead of staying in grace.
24 //
25 // These tests pin the fix: the snapshot path preserves existing lock
26 // columns when status=past_due, and clears them ONLY when transitioning
27 // from past_due/unpaid/incomplete → active/trialing (the recovery path
28 // that MarkPaymentSucceeded normally drives, but subscription.updated
29 // status flips also need to handle it).
30
31 func TestApplySubscriptionSnapshotPreservesGraceLockOnPastDue(t *testing.T) {
32 _, deps, org := setup(t)
33 ctx := context.Background()
34
35 // 1. Establish past_due via the invoice path — MarkPastDue sets
36 // locked_at + grace_until.
37 graceUntil := time.Now().UTC().Add(72 * time.Hour).Truncate(time.Second)
38 if _, err := billing.MarkPastDue(ctx, deps, org.ID, graceUntil, "evt_past_due"); err != nil {
39 t.Fatalf("MarkPastDue: %v", err)
40 }
41 before, err := billing.GetOrgBillingState(ctx, deps, org.ID)
42 if err != nil {
43 t.Fatalf("get before: %v", err)
44 }
45 if !before.LockedAt.Valid || !before.GraceUntil.Valid {
46 t.Fatalf("expected locked + grace set after MarkPastDue, got %+v", before)
47 }
48
49 // 2. customer.subscription.updated[status=past_due] arrives next.
50 // Pre-fix: this wiped the lock. Post-fix: COALESCE preserves it.
51 _, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
52 OrgID: org.ID,
53 Plan: billing.PlanTeam,
54 Status: billing.SubscriptionStatusPastDue,
55 StripeSubscriptionID: "sub_past_due",
56 LastWebhookEventID: "evt_sub_past_due",
57 })
58 if err != nil {
59 t.Fatalf("ApplySubscriptionSnapshot: %v", err)
60 }
61 after, err := billing.GetOrgBillingState(ctx, deps, org.ID)
62 if err != nil {
63 t.Fatalf("get after: %v", err)
64 }
65 if !after.LockedAt.Valid {
66 t.Fatalf("locked_at lost on subscription.updated[past_due]: %+v", after)
67 }
68 if !after.GraceUntil.Valid || !after.GraceUntil.Time.Equal(graceUntil) {
69 t.Fatalf("grace_until lost on subscription.updated[past_due]: got %+v want %v", after.GraceUntil, graceUntil)
70 }
71 if !after.LockReason.Valid || after.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue {
72 t.Fatalf("lock_reason lost: %+v", after.LockReason)
73 }
74 }
75
76 // TestApplySubscriptionSnapshotClearsLocksOnPastDueRecovery covers the
77 // other half: subscription.updated[status=active] arriving after a
78 // past_due episode clears the lock columns even without an explicit
79 // invoice.payment_succeeded.
80 func TestApplySubscriptionSnapshotClearsLocksOnPastDueRecovery(t *testing.T) {
81 _, deps, org := setup(t)
82 ctx := context.Background()
83
84 graceUntil := time.Now().UTC().Add(72 * time.Hour)
85 if _, err := billing.MarkPastDue(ctx, deps, org.ID, graceUntil, "evt_past_due"); err != nil {
86 t.Fatalf("MarkPastDue: %v", err)
87 }
88 // recovery via subscription.updated[active]:
89 if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
90 OrgID: org.ID,
91 Plan: billing.PlanTeam,
92 Status: billing.SubscriptionStatusActive,
93 StripeSubscriptionID: "sub_recovered",
94 LastWebhookEventID: "evt_sub_active",
95 }); err != nil {
96 t.Fatalf("ApplySubscriptionSnapshot active: %v", err)
97 }
98 state, err := billing.GetOrgBillingState(ctx, deps, org.ID)
99 if err != nil {
100 t.Fatalf("get: %v", err)
101 }
102 if state.LockedAt.Valid {
103 t.Errorf("locked_at should be NULL after recovery, got %+v", state.LockedAt)
104 }
105 if state.GraceUntil.Valid {
106 t.Errorf("grace_until should be NULL after recovery, got %+v", state.GraceUntil)
107 }
108 if state.LockReason.Valid {
109 t.Errorf("lock_reason should be NULL after recovery, got %+v", state.LockReason)
110 }
111 }
112
113 // TestApplyUserSubscriptionSnapshotPreservesGraceLockOnPastDue mirrors
114 // the org-side test for the user-tier (Pro). Same fix shape, same
115 // risk: subscription.updated for a Pro user in grace must not wipe
116 // the grace fields.
117 func TestApplyUserSubscriptionSnapshotPreservesGraceLockOnPastDue(t *testing.T) {
118 _, deps, org := setup(t)
119 ctx := context.Background()
120 _ = org
121
122 // Create a fresh user — setup gives us an owner via insertSetupUser.
123 userID := insertSetupUser(t, deps, "prolocked")
124 q := billingdb.New()
125 // Seed Pro+active first (the trigger seeds 'free,none' by default).
126 if _, err := q.ApplyUserSubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplyUserSubscriptionSnapshotParams{
127 UserID: userID,
128 Plan: billingdb.UserPlanPro,
129 SubscriptionStatus: billingdb.BillingSubscriptionStatusActive,
130 StripeSubscriptionID: pgtype.Text{String: "sub_user_lock", Valid: true},
131 LastWebhookEventID: "evt_seed_pro",
132 }); err != nil {
133 t.Fatalf("seed Pro: %v", err)
134 }
135 graceUntil := time.Now().UTC().Add(72 * time.Hour).Truncate(time.Second)
136 if _, err := q.MarkUserPastDue(ctx, deps.Pool, billingdb.MarkUserPastDueParams{
137 UserID: userID,
138 GraceUntil: pgtype.Timestamptz{Time: graceUntil, Valid: true},
139 LastWebhookEventID: "evt_user_past_due",
140 }); err != nil {
141 t.Fatalf("MarkUserPastDue: %v", err)
142 }
143 // Now apply subscription.updated[past_due] via the snapshot path.
144 if _, err := q.ApplyUserSubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplyUserSubscriptionSnapshotParams{
145 UserID: userID,
146 Plan: billingdb.UserPlanPro,
147 SubscriptionStatus: billingdb.BillingSubscriptionStatusPastDue,
148 StripeSubscriptionID: pgtype.Text{String: "sub_user_lock", Valid: true},
149 LastWebhookEventID: "evt_user_sub_past_due",
150 }); err != nil {
151 t.Fatalf("apply past_due via snapshot: %v", err)
152 }
153 state, err := q.GetUserBillingState(ctx, deps.Pool, userID)
154 if err != nil {
155 t.Fatalf("get user state: %v", err)
156 }
157 if !state.LockedAt.Valid {
158 t.Errorf("user locked_at wiped by snapshot path: %+v", state)
159 }
160 if !state.GraceUntil.Valid || !state.GraceUntil.Time.Equal(graceUntil) {
161 t.Errorf("user grace_until wiped/changed: got %+v want %v", state.GraceUntil, graceUntil)
162 }
163 }
164
165 func insertSetupUser(t *testing.T, deps billing.Deps, username string) int64 {
166 t.Helper()
167 var id int64
168 if err := deps.Pool.QueryRow(context.Background(),
169 `INSERT INTO users (username, password_hash) VALUES ($1, $2) RETURNING id`,
170 username,
171 "$argon2id$v=19$m=16384,t=1,p=1$AAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
172 ).Scan(&id); err != nil {
173 t.Fatalf("insert user %s: %v", username, err)
174 }
175 return id
176 }
177