@@ -0,0 +1,176 @@ |
| | 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| | 2 | + |
| | 3 | +package billing_test |
| | 4 | + |
| | 5 | +import ( |
| | 6 | + "context" |
| | 7 | + "testing" |
| | 8 | + "time" |
| | 9 | + |
| | 10 | + "github.com/jackc/pgx/v5/pgtype" |
| | 11 | + |
| | 12 | + "github.com/tenseleyFlow/shithub/internal/billing" |
| | 13 | + billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc" |
| | 14 | +) |
| | 15 | + |
| | 16 | +// PRO08 D1 — lock-column preservation across snapshot applies. |
| | 17 | +// |
| | 18 | +// Stripe delivers `customer.subscription.updated[status=past_due]` and |
| | 19 | +// `invoice.payment_failed` in unguaranteed order. Pre-PRO08, the |
| | 20 | +// snapshot CTE unconditionally NULLed locked_at/grace_until/lock_reason, |
| | 21 | +// so a subscription.updated arriving AFTER invoice.payment_failed |
| | 22 | +// wiped the grace lock — the user would see the "Pro features are |
| | 23 | +// read-only" banner immediately instead of staying in grace. |
| | 24 | +// |
| | 25 | +// These tests pin the fix: the snapshot path preserves existing lock |
| | 26 | +// columns when status=past_due, and clears them ONLY when transitioning |
| | 27 | +// from past_due/unpaid/incomplete → active/trialing (the recovery path |
| | 28 | +// that MarkPaymentSucceeded normally drives, but subscription.updated |
| | 29 | +// status flips also need to handle it). |
| | 30 | + |
| | 31 | +func TestApplySubscriptionSnapshotPreservesGraceLockOnPastDue(t *testing.T) { |
| | 32 | + _, deps, org := setup(t) |
| | 33 | + ctx := context.Background() |
| | 34 | + |
| | 35 | + // 1. Establish past_due via the invoice path — MarkPastDue sets |
| | 36 | + // locked_at + grace_until. |
| | 37 | + graceUntil := time.Now().UTC().Add(72 * time.Hour).Truncate(time.Second) |
| | 38 | + if _, err := billing.MarkPastDue(ctx, deps, org.ID, graceUntil, "evt_past_due"); err != nil { |
| | 39 | + t.Fatalf("MarkPastDue: %v", err) |
| | 40 | + } |
| | 41 | + before, err := billing.GetOrgBillingState(ctx, deps, org.ID) |
| | 42 | + if err != nil { |
| | 43 | + t.Fatalf("get before: %v", err) |
| | 44 | + } |
| | 45 | + if !before.LockedAt.Valid || !before.GraceUntil.Valid { |
| | 46 | + t.Fatalf("expected locked + grace set after MarkPastDue, got %+v", before) |
| | 47 | + } |
| | 48 | + |
| | 49 | + // 2. customer.subscription.updated[status=past_due] arrives next. |
| | 50 | + // Pre-fix: this wiped the lock. Post-fix: COALESCE preserves it. |
| | 51 | + _, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{ |
| | 52 | + OrgID: org.ID, |
| | 53 | + Plan: billing.PlanTeam, |
| | 54 | + Status: billing.SubscriptionStatusPastDue, |
| | 55 | + StripeSubscriptionID: "sub_past_due", |
| | 56 | + LastWebhookEventID: "evt_sub_past_due", |
| | 57 | + }) |
| | 58 | + if err != nil { |
| | 59 | + t.Fatalf("ApplySubscriptionSnapshot: %v", err) |
| | 60 | + } |
| | 61 | + after, err := billing.GetOrgBillingState(ctx, deps, org.ID) |
| | 62 | + if err != nil { |
| | 63 | + t.Fatalf("get after: %v", err) |
| | 64 | + } |
| | 65 | + if !after.LockedAt.Valid { |
| | 66 | + t.Fatalf("locked_at lost on subscription.updated[past_due]: %+v", after) |
| | 67 | + } |
| | 68 | + if !after.GraceUntil.Valid || !after.GraceUntil.Time.Equal(graceUntil) { |
| | 69 | + t.Fatalf("grace_until lost on subscription.updated[past_due]: got %+v want %v", after.GraceUntil, graceUntil) |
| | 70 | + } |
| | 71 | + if !after.LockReason.Valid || after.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue { |
| | 72 | + t.Fatalf("lock_reason lost: %+v", after.LockReason) |
| | 73 | + } |
| | 74 | +} |
| | 75 | + |
| | 76 | +// TestApplySubscriptionSnapshotClearsLocksOnPastDueRecovery covers the |
| | 77 | +// other half: subscription.updated[status=active] arriving after a |
| | 78 | +// past_due episode clears the lock columns even without an explicit |
| | 79 | +// invoice.payment_succeeded. |
| | 80 | +func TestApplySubscriptionSnapshotClearsLocksOnPastDueRecovery(t *testing.T) { |
| | 81 | + _, deps, org := setup(t) |
| | 82 | + ctx := context.Background() |
| | 83 | + |
| | 84 | + graceUntil := time.Now().UTC().Add(72 * time.Hour) |
| | 85 | + if _, err := billing.MarkPastDue(ctx, deps, org.ID, graceUntil, "evt_past_due"); err != nil { |
| | 86 | + t.Fatalf("MarkPastDue: %v", err) |
| | 87 | + } |
| | 88 | + // recovery via subscription.updated[active]: |
| | 89 | + if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{ |
| | 90 | + OrgID: org.ID, |
| | 91 | + Plan: billing.PlanTeam, |
| | 92 | + Status: billing.SubscriptionStatusActive, |
| | 93 | + StripeSubscriptionID: "sub_recovered", |
| | 94 | + LastWebhookEventID: "evt_sub_active", |
| | 95 | + }); err != nil { |
| | 96 | + t.Fatalf("ApplySubscriptionSnapshot active: %v", err) |
| | 97 | + } |
| | 98 | + state, err := billing.GetOrgBillingState(ctx, deps, org.ID) |
| | 99 | + if err != nil { |
| | 100 | + t.Fatalf("get: %v", err) |
| | 101 | + } |
| | 102 | + if state.LockedAt.Valid { |
| | 103 | + t.Errorf("locked_at should be NULL after recovery, got %+v", state.LockedAt) |
| | 104 | + } |
| | 105 | + if state.GraceUntil.Valid { |
| | 106 | + t.Errorf("grace_until should be NULL after recovery, got %+v", state.GraceUntil) |
| | 107 | + } |
| | 108 | + if state.LockReason.Valid { |
| | 109 | + t.Errorf("lock_reason should be NULL after recovery, got %+v", state.LockReason) |
| | 110 | + } |
| | 111 | +} |
| | 112 | + |
| | 113 | +// TestApplyUserSubscriptionSnapshotPreservesGraceLockOnPastDue mirrors |
| | 114 | +// the org-side test for the user-tier (Pro). Same fix shape, same |
| | 115 | +// risk: subscription.updated for a Pro user in grace must not wipe |
| | 116 | +// the grace fields. |
| | 117 | +func TestApplyUserSubscriptionSnapshotPreservesGraceLockOnPastDue(t *testing.T) { |
| | 118 | + _, deps, org := setup(t) |
| | 119 | + ctx := context.Background() |
| | 120 | + _ = org |
| | 121 | + |
| | 122 | + // Create a fresh user — setup gives us an owner via insertSetupUser. |
| | 123 | + userID := insertSetupUser(t, deps, "prolocked") |
| | 124 | + q := billingdb.New() |
| | 125 | + // Seed Pro+active first (the trigger seeds 'free,none' by default). |
| | 126 | + if _, err := q.ApplyUserSubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplyUserSubscriptionSnapshotParams{ |
| | 127 | + UserID: userID, |
| | 128 | + Plan: billingdb.UserPlanPro, |
| | 129 | + SubscriptionStatus: billingdb.BillingSubscriptionStatusActive, |
| | 130 | + StripeSubscriptionID: pgtype.Text{String: "sub_user_lock", Valid: true}, |
| | 131 | + LastWebhookEventID: "evt_seed_pro", |
| | 132 | + }); err != nil { |
| | 133 | + t.Fatalf("seed Pro: %v", err) |
| | 134 | + } |
| | 135 | + graceUntil := time.Now().UTC().Add(72 * time.Hour).Truncate(time.Second) |
| | 136 | + if _, err := q.MarkUserPastDue(ctx, deps.Pool, billingdb.MarkUserPastDueParams{ |
| | 137 | + UserID: userID, |
| | 138 | + GraceUntil: pgtype.Timestamptz{Time: graceUntil, Valid: true}, |
| | 139 | + LastWebhookEventID: "evt_user_past_due", |
| | 140 | + }); err != nil { |
| | 141 | + t.Fatalf("MarkUserPastDue: %v", err) |
| | 142 | + } |
| | 143 | + // Now apply subscription.updated[past_due] via the snapshot path. |
| | 144 | + if _, err := q.ApplyUserSubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplyUserSubscriptionSnapshotParams{ |
| | 145 | + UserID: userID, |
| | 146 | + Plan: billingdb.UserPlanPro, |
| | 147 | + SubscriptionStatus: billingdb.BillingSubscriptionStatusPastDue, |
| | 148 | + StripeSubscriptionID: pgtype.Text{String: "sub_user_lock", Valid: true}, |
| | 149 | + LastWebhookEventID: "evt_user_sub_past_due", |
| | 150 | + }); err != nil { |
| | 151 | + t.Fatalf("apply past_due via snapshot: %v", err) |
| | 152 | + } |
| | 153 | + state, err := q.GetUserBillingState(ctx, deps.Pool, userID) |
| | 154 | + if err != nil { |
| | 155 | + t.Fatalf("get user state: %v", err) |
| | 156 | + } |
| | 157 | + if !state.LockedAt.Valid { |
| | 158 | + t.Errorf("user locked_at wiped by snapshot path: %+v", state) |
| | 159 | + } |
| | 160 | + if !state.GraceUntil.Valid || !state.GraceUntil.Time.Equal(graceUntil) { |
| | 161 | + t.Errorf("user grace_until wiped/changed: got %+v want %v", state.GraceUntil, graceUntil) |
| | 162 | + } |
| | 163 | +} |
| | 164 | + |
| | 165 | +func insertSetupUser(t *testing.T, deps billing.Deps, username string) int64 { |
| | 166 | + t.Helper() |
| | 167 | + var id int64 |
| | 168 | + if err := deps.Pool.QueryRow(context.Background(), |
| | 169 | + `INSERT INTO users (username, password_hash) VALUES ($1, $2) RETURNING id`, |
| | 170 | + username, |
| | 171 | + "$argon2id$v=19$m=16384,t=1,p=1$AAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", |
| | 172 | + ).Scan(&id); err != nil { |
| | 173 | + t.Fatalf("insert user %s: %v", username, err) |
| | 174 | + } |
| | 175 | + return id |
| | 176 | +} |