Go · 7112 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package billing_test
4
5 import (
6 "context"
7 "io"
8 "log/slog"
9 "testing"
10 "time"
11
12 "github.com/jackc/pgx/v5/pgxpool"
13
14 "github.com/tenseleyFlow/shithub/internal/billing"
15 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
16 "github.com/tenseleyFlow/shithub/internal/orgs"
17 orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc"
18 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
19 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
20 )
21
22 const fixtureHash = "$argon2id$v=19$m=16384,t=1,p=1$" +
23 "AAAAAAAAAAAAAAAA$" +
24 "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
25
26 func setup(t *testing.T) (*pgxpool.Pool, billing.Deps, orgsdb.Org) {
27 t.Helper()
28 pool := dbtest.NewTestDB(t)
29 ctx := context.Background()
30 u, err := usersdb.New().CreateUser(ctx, pool, usersdb.CreateUserParams{
31 Username: "alice", DisplayName: "Alice", PasswordHash: fixtureHash,
32 })
33 if err != nil {
34 t.Fatalf("create user: %v", err)
35 }
36 odeps := orgs.Deps{
37 Pool: pool,
38 Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
39 }
40 org, err := orgs.Create(ctx, odeps, orgs.CreateParams{
41 Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: u.ID,
42 })
43 if err != nil {
44 t.Fatalf("create org: %v", err)
45 }
46 return pool, billing.Deps{Pool: pool}, org
47 }
48
49 func TestBillingStateTransitions(t *testing.T) {
50 pool, deps, org := setup(t)
51 ctx := context.Background()
52
53 state, err := billing.GetOrgBillingState(ctx, deps, org.ID)
54 if err != nil {
55 t.Fatalf("GetOrgBillingState: %v", err)
56 }
57 if state.Plan != billing.PlanFree || state.SubscriptionStatus != billing.SubscriptionStatusNone {
58 t.Fatalf("new org state: plan=%s status=%s", state.Plan, state.SubscriptionStatus)
59 }
60
61 state, err = billing.SetStripeCustomer(ctx, deps, org.ID, "cus_test")
62 if err != nil {
63 t.Fatalf("SetStripeCustomer: %v", err)
64 }
65 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test" {
66 t.Fatalf("stripe customer not set: %+v", state.StripeCustomerID)
67 }
68
69 start := time.Now().UTC().Truncate(time.Second)
70 state, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
71 OrgID: org.ID,
72 Plan: billing.PlanTeam,
73 Status: billing.SubscriptionStatusActive,
74 StripeSubscriptionID: "sub_test",
75 StripeSubscriptionItemID: "si_test",
76 CurrentPeriodStart: start,
77 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
78 LastWebhookEventID: "evt_active",
79 })
80 if err != nil {
81 t.Fatalf("ApplySubscriptionSnapshot active: %v", err)
82 }
83 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusActive)
84 if state.LockedAt.Valid || state.LockReason.Valid {
85 t.Fatalf("active subscription should not be locked: %+v", state)
86 }
87 assertOrgPlan(t, pool, org.ID, orgsdb.OrgPlanTeam)
88
89 grace := start.Add(7 * 24 * time.Hour)
90 state, err = billing.MarkPastDue(ctx, deps, org.ID, grace, "evt_past_due")
91 if err != nil {
92 t.Fatalf("MarkPastDue: %v", err)
93 }
94 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusPastDue)
95 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue {
96 t.Fatalf("past_due should set lock fields: %+v", state)
97 }
98 if !state.GraceUntil.Valid {
99 t.Fatalf("past_due should set grace_until")
100 }
101
102 state, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
103 OrgID: org.ID,
104 Plan: billing.PlanTeam,
105 Status: billing.SubscriptionStatusActive,
106 StripeSubscriptionID: "sub_test",
107 StripeSubscriptionItemID: "si_test",
108 CurrentPeriodStart: start,
109 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
110 LastWebhookEventID: "evt_recovered",
111 })
112 if err != nil {
113 t.Fatalf("ApplySubscriptionSnapshot recovered: %v", err)
114 }
115 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusActive)
116 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid || state.PastDueAt.Valid {
117 t.Fatalf("recovered subscription should clear lock/grace/past_due: %+v", state)
118 }
119
120 state, err = billing.MarkCanceled(ctx, deps, org.ID, "evt_canceled")
121 if err != nil {
122 t.Fatalf("MarkCanceled: %v", err)
123 }
124 assertState(t, state, billing.PlanFree, billing.SubscriptionStatusCanceled)
125 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonCanceled {
126 t.Fatalf("canceled subscription should set canceled lock: %+v", state)
127 }
128 assertOrgPlan(t, pool, org.ID, orgsdb.OrgPlanFree)
129
130 state, err = billing.ClearBillingLock(ctx, deps, org.ID)
131 if err != nil {
132 t.Fatalf("ClearBillingLock: %v", err)
133 }
134 assertState(t, state, billing.PlanFree, billing.SubscriptionStatusNone)
135 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid {
136 t.Fatalf("free state should clear billing lock: %+v", state)
137 }
138 }
139
140 func TestRecordWebhookEventIsIdempotent(t *testing.T) {
141 _, deps, _ := setup(t)
142 ctx := context.Background()
143
144 event := billing.WebhookEvent{
145 ProviderEventID: "evt_test",
146 EventType: "customer.subscription.updated",
147 APIVersion: "2024-06-20",
148 Payload: []byte(`{"id":"evt_test"}`),
149 }
150 row, created, err := billing.RecordWebhookEvent(ctx, deps, event)
151 if err != nil {
152 t.Fatalf("RecordWebhookEvent first: %v", err)
153 }
154 if !created || row.ProviderEventID != "evt_test" {
155 t.Fatalf("first receipt created=%v row=%+v", created, row)
156 }
157
158 _, created, err = billing.RecordWebhookEvent(ctx, deps, event)
159 if err != nil {
160 t.Fatalf("RecordWebhookEvent duplicate: %v", err)
161 }
162 if created {
163 t.Fatalf("duplicate receipt should not be created")
164 }
165 }
166
167 func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) {
168 _, deps, org := setup(t)
169 ctx := context.Background()
170
171 snap, err := billing.SyncSeatSnapshot(ctx, deps, billing.SeatSnapshot{
172 OrgID: org.ID,
173 StripeSubscriptionID: "sub_test",
174 ActiveMembers: 2,
175 BillableSeats: 2,
176 })
177 if err != nil {
178 t.Fatalf("SyncSeatSnapshot: %v", err)
179 }
180 if snap.ActiveMembers != 2 || snap.BillableSeats != 2 || snap.Source != "local" {
181 t.Fatalf("unexpected snapshot: %+v", snap)
182 }
183 state, err := billing.GetOrgBillingState(ctx, deps, org.ID)
184 if err != nil {
185 t.Fatalf("GetOrgBillingState: %v", err)
186 }
187 if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid {
188 t.Fatalf("state did not record seat snapshot: %+v", state)
189 }
190 }
191
192 func assertState(t *testing.T, state billing.State, plan billing.Plan, status billing.SubscriptionStatus) {
193 t.Helper()
194 if state.Plan != plan || state.SubscriptionStatus != status {
195 t.Fatalf("state: want plan=%s status=%s, got plan=%s status=%s", plan, status, state.Plan, state.SubscriptionStatus)
196 }
197 }
198
199 func assertOrgPlan(t *testing.T, pool *pgxpool.Pool, orgID int64, want orgsdb.OrgPlan) {
200 t.Helper()
201 row, err := orgsdb.New().GetOrgByID(context.Background(), pool, orgID)
202 if err != nil {
203 t.Fatalf("GetOrgByID: %v", err)
204 }
205 if row.Plan != want {
206 t.Fatalf("org plan: want %s, got %s", want, row.Plan)
207 }
208 }
209