Go · 15336 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package billing_test
4
5 import (
6 "context"
7 "io"
8 "log/slog"
9 "testing"
10 "time"
11
12 "github.com/jackc/pgx/v5/pgxpool"
13
14 "github.com/tenseleyFlow/shithub/internal/billing"
15 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
16 "github.com/tenseleyFlow/shithub/internal/orgs"
17 orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc"
18 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
19 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
20 )
21
// fixtureHash is a syntactically well-formed argon2id PHC string used as the
// PasswordHash of fixture users. Its salt and digest are placeholder "A"
// runs; nothing in this file verifies a password against it — it only needs
// to be accepted by CreateUser.
const fixtureHash = "" +
	"$argon2id$v=19$m=16384,t=1,p=1$" + // algorithm, version, cost parameters
	"AAAAAAAAAAAAAAAA$" + // salt segment
	"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" // digest segment
25
26 func setup(t *testing.T) (*pgxpool.Pool, billing.Deps, orgsdb.Org) {
27 t.Helper()
28 pool := dbtest.NewTestDB(t)
29 ctx := context.Background()
30 u, err := usersdb.New().CreateUser(ctx, pool, usersdb.CreateUserParams{
31 Username: "alice", DisplayName: "Alice", PasswordHash: fixtureHash,
32 })
33 if err != nil {
34 t.Fatalf("create user: %v", err)
35 }
36 odeps := orgs.Deps{
37 Pool: pool,
38 Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
39 }
40 org, err := orgs.Create(ctx, odeps, orgs.CreateParams{
41 Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: u.ID,
42 })
43 if err != nil {
44 t.Fatalf("create org: %v", err)
45 }
46 return pool, billing.Deps{Pool: pool}, org
47 }
48
49 func TestBillingStateTransitions(t *testing.T) {
50 pool, deps, org := setup(t)
51 ctx := context.Background()
52
53 state, err := billing.GetOrgBillingState(ctx, deps, org.ID)
54 if err != nil {
55 t.Fatalf("GetOrgBillingState: %v", err)
56 }
57 if state.Plan != billing.PlanFree || state.SubscriptionStatus != billing.SubscriptionStatusNone {
58 t.Fatalf("new org state: plan=%s status=%s", state.Plan, state.SubscriptionStatus)
59 }
60
61 state, err = billing.SetStripeCustomer(ctx, deps, org.ID, "cus_test")
62 if err != nil {
63 t.Fatalf("SetStripeCustomer: %v", err)
64 }
65 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test" {
66 t.Fatalf("stripe customer not set: %+v", state.StripeCustomerID)
67 }
68
69 start := time.Now().UTC().Truncate(time.Second)
70 state, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
71 OrgID: org.ID,
72 Plan: billing.PlanTeam,
73 Status: billing.SubscriptionStatusActive,
74 StripeSubscriptionID: "sub_test",
75 StripeSubscriptionItemID: "si_test",
76 CurrentPeriodStart: start,
77 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
78 LastWebhookEventID: "evt_active",
79 })
80 if err != nil {
81 t.Fatalf("ApplySubscriptionSnapshot active: %v", err)
82 }
83 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusActive)
84 if state.LockedAt.Valid || state.LockReason.Valid {
85 t.Fatalf("active subscription should not be locked: %+v", state)
86 }
87 assertOrgPlan(t, pool, org.ID, orgsdb.OrgPlanTeam)
88
89 grace := start.Add(7 * 24 * time.Hour)
90 state, err = billing.MarkPastDue(ctx, deps, org.ID, grace, "evt_past_due")
91 if err != nil {
92 t.Fatalf("MarkPastDue: %v", err)
93 }
94 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusPastDue)
95 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue {
96 t.Fatalf("past_due should set lock fields: %+v", state)
97 }
98 if !state.GraceUntil.Valid {
99 t.Fatalf("past_due should set grace_until")
100 }
101
102 state, err = billing.MarkPaymentSucceeded(ctx, deps, org.ID, "evt_paid")
103 if err != nil {
104 t.Fatalf("MarkPaymentSucceeded: %v", err)
105 }
106 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusActive)
107 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid || state.PastDueAt.Valid {
108 t.Fatalf("payment recovery should clear lock/grace/past_due: %+v", state)
109 }
110 assertOrgPlan(t, pool, org.ID, orgsdb.OrgPlanTeam)
111
112 state, err = billing.MarkPastDue(ctx, deps, org.ID, grace, "evt_past_due_again")
113 if err != nil {
114 t.Fatalf("MarkPastDue again: %v", err)
115 }
116 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusPastDue)
117
118 state, err = billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
119 OrgID: org.ID,
120 Plan: billing.PlanTeam,
121 Status: billing.SubscriptionStatusActive,
122 StripeSubscriptionID: "sub_test",
123 StripeSubscriptionItemID: "si_test",
124 CurrentPeriodStart: start,
125 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
126 LastWebhookEventID: "evt_recovered",
127 })
128 if err != nil {
129 t.Fatalf("ApplySubscriptionSnapshot recovered: %v", err)
130 }
131 assertState(t, state, billing.PlanTeam, billing.SubscriptionStatusActive)
132 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid || state.PastDueAt.Valid {
133 t.Fatalf("recovered subscription should clear lock/grace/past_due: %+v", state)
134 }
135
136 state, err = billing.MarkCanceled(ctx, deps, org.ID, "evt_canceled")
137 if err != nil {
138 t.Fatalf("MarkCanceled: %v", err)
139 }
140 assertState(t, state, billing.PlanFree, billing.SubscriptionStatusCanceled)
141 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonCanceled {
142 t.Fatalf("canceled subscription should set canceled lock: %+v", state)
143 }
144 assertOrgPlan(t, pool, org.ID, orgsdb.OrgPlanFree)
145
146 state, err = billing.ClearBillingLock(ctx, deps, org.ID)
147 if err != nil {
148 t.Fatalf("ClearBillingLock: %v", err)
149 }
150 assertState(t, state, billing.PlanFree, billing.SubscriptionStatusNone)
151 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid {
152 t.Fatalf("free state should clear billing lock: %+v", state)
153 }
154 }
155
156 func TestRecordWebhookEventIsIdempotent(t *testing.T) {
157 _, deps, _ := setup(t)
158 ctx := context.Background()
159
160 event := billing.WebhookEvent{
161 ProviderEventID: "evt_test",
162 EventType: "customer.subscription.updated",
163 APIVersion: "2024-06-20",
164 Payload: []byte(`{"id":"evt_test"}`),
165 }
166 row, created, err := billing.RecordWebhookEvent(ctx, deps, event)
167 if err != nil {
168 t.Fatalf("RecordWebhookEvent first: %v", err)
169 }
170 if !created || row.ProviderEventID != "evt_test" {
171 t.Fatalf("first receipt created=%v row=%+v", created, row)
172 }
173
174 dup, created, err := billing.RecordWebhookEvent(ctx, deps, event)
175 if err != nil {
176 t.Fatalf("RecordWebhookEvent duplicate: %v", err)
177 }
178 if created {
179 t.Fatalf("duplicate receipt should not be created")
180 }
181 if dup.ID != row.ID || dup.ProcessedAt.Valid {
182 t.Fatalf("duplicate should return existing unprocessed receipt: first=%+v dup=%+v", row, dup)
183 }
184
185 if _, err := billing.MarkWebhookEventProcessed(ctx, deps, event.ProviderEventID); err != nil {
186 t.Fatalf("MarkWebhookEventProcessed: %v", err)
187 }
188 receipt, err := billing.GetWebhookEventReceipt(ctx, deps, event.ProviderEventID)
189 if err != nil {
190 t.Fatalf("GetWebhookEventReceipt: %v", err)
191 }
192 if receipt.ProviderEventID != event.ProviderEventID || !receipt.ProcessedAt.Valid {
193 t.Fatalf("unexpected receipt lookup: %+v", receipt)
194 }
195 dup, created, err = billing.RecordWebhookEvent(ctx, deps, event)
196 if err != nil {
197 t.Fatalf("RecordWebhookEvent after processed: %v", err)
198 }
199 if created {
200 t.Fatalf("processed duplicate should not be created")
201 }
202 if dup.ID != row.ID || !dup.ProcessedAt.Valid {
203 t.Fatalf("processed duplicate should return existing processed receipt: first=%+v dup=%+v", row, dup)
204 }
205 if _, err := billing.MarkWebhookEventFailed(ctx, deps, event.ProviderEventID, "late duplicate"); err != nil {
206 t.Fatalf("MarkWebhookEventFailed: %v", err)
207 }
208 }
209
210 // TestSetWebhookEventSubjectForPrincipalRecordsAuditTrail locks PRO08
211 // A2: after a successful resolve in the webhook apply path, the
212 // receipt row carries (subject_kind, subject_id) so the operator
213 // query for "which subject did this event apply to" works.
214 func TestSetWebhookEventSubjectForPrincipalRecordsAuditTrail(t *testing.T) {
215 _, deps, org := setup(t)
216 ctx := context.Background()
217 if _, _, err := billing.RecordWebhookEvent(ctx, deps, billing.WebhookEvent{
218 ProviderEventID: "evt_subject_record",
219 EventType: "customer.subscription.updated",
220 APIVersion: "2024-06-20",
221 Payload: []byte(`{"id":"evt_subject_record"}`),
222 }); err != nil {
223 t.Fatalf("RecordWebhookEvent: %v", err)
224 }
225 if err := billing.SetWebhookEventSubjectForPrincipal(ctx, deps, "evt_subject_record", billing.PrincipalForOrg(org.ID)); err != nil {
226 t.Fatalf("SetWebhookEventSubjectForPrincipal: %v", err)
227 }
228 receipt, err := billing.GetWebhookEventReceipt(ctx, deps, "evt_subject_record")
229 if err != nil {
230 t.Fatalf("GetWebhookEventReceipt: %v", err)
231 }
232 if !receipt.SubjectKind.Valid || receipt.SubjectKind.BillingSubjectKind != billingdb.BillingSubjectKindOrg {
233 t.Fatalf("subject_kind: got %+v, want org", receipt.SubjectKind)
234 }
235 if !receipt.SubjectID.Valid || receipt.SubjectID.Int64 != org.ID {
236 t.Fatalf("subject_id: got %+v, want %d", receipt.SubjectID, org.ID)
237 }
238 }
239
240 // TestListFailedWebhookEventsReturnsErroredAndStuckEntries locks the
241 // operator inspection query used in the runbook. A receipt with a
242 // non-empty process_error or with processing_attempts > 0 but no
243 // processed_at must appear; a clean unprocessed row (attempts=0)
244 // must NOT appear.
245 func TestListFailedWebhookEventsReturnsErroredAndStuckEntries(t *testing.T) {
246 _, deps, _ := setup(t)
247 ctx := context.Background()
248
249 // 1. failed row (process_error non-empty)
250 if _, _, err := billing.RecordWebhookEvent(ctx, deps, billing.WebhookEvent{
251 ProviderEventID: "evt_failed", EventType: "x", Payload: []byte(`{}`),
252 }); err != nil {
253 t.Fatalf("record failed: %v", err)
254 }
255 if _, err := billing.MarkWebhookEventFailed(ctx, deps, "evt_failed", "boom"); err != nil {
256 t.Fatalf("mark failed: %v", err)
257 }
258
259 // 2. clean processed row — must NOT appear.
260 if _, _, err := billing.RecordWebhookEvent(ctx, deps, billing.WebhookEvent{
261 ProviderEventID: "evt_clean", EventType: "x", Payload: []byte(`{}`),
262 }); err != nil {
263 t.Fatalf("record clean: %v", err)
264 }
265 if _, err := billing.MarkWebhookEventProcessed(ctx, deps, "evt_clean"); err != nil {
266 t.Fatalf("mark clean: %v", err)
267 }
268
269 // 3. brand-new untouched row — must NOT appear.
270 if _, _, err := billing.RecordWebhookEvent(ctx, deps, billing.WebhookEvent{
271 ProviderEventID: "evt_new", EventType: "x", Payload: []byte(`{}`),
272 }); err != nil {
273 t.Fatalf("record new: %v", err)
274 }
275
276 rows, err := billing.ListFailedWebhookEvents(ctx, deps, 50)
277 if err != nil {
278 t.Fatalf("ListFailedWebhookEvents: %v", err)
279 }
280 got := map[string]bool{}
281 for _, r := range rows {
282 got[r.ProviderEventID] = true
283 }
284 if !got["evt_failed"] {
285 t.Errorf("expected evt_failed in failed list, got %v", got)
286 }
287 if got["evt_clean"] {
288 t.Errorf("evt_clean (processed, no error) leaked into failed list: %v", got)
289 }
290 if got["evt_new"] {
291 t.Errorf("evt_new (untouched) leaked into failed list: %v", got)
292 }
293 }
294
295 func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) {
296 _, deps, org := setup(t)
297 ctx := context.Background()
298
299 snap, err := billing.SyncSeatSnapshot(ctx, deps, billing.SeatSnapshot{
300 OrgID: org.ID,
301 StripeSubscriptionID: "sub_test",
302 ActiveMembers: 2,
303 BillableSeats: 2,
304 })
305 if err != nil {
306 t.Fatalf("SyncSeatSnapshot: %v", err)
307 }
308 if snap.ActiveMembers != 2 || snap.BillableSeats != 2 || snap.Source != "local" {
309 t.Fatalf("unexpected snapshot: %+v", snap)
310 }
311 state, err := billing.GetOrgBillingState(ctx, deps, org.ID)
312 if err != nil {
313 t.Fatalf("GetOrgBillingState: %v", err)
314 }
315 if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid {
316 t.Fatalf("state did not record seat snapshot: %+v", state)
317 }
318
319 count, err := billing.CountBillableOrgMembers(ctx, deps, org.ID)
320 if err != nil {
321 t.Fatalf("CountBillableOrgMembers: %v", err)
322 }
323 if count != 1 {
324 t.Fatalf("billable members: got %d, want 1", count)
325 }
326
327 if _, err := deps.Pool.Exec(ctx, `
328 INSERT INTO org_invitations (org_id, target_email, role, token_hash, expires_at)
329 VALUES ($1, 'pending@example.com', 'member', '\x010203', now() + interval '1 day'),
330 ($1, 'expired@example.com', 'member', '\x040506', now() - interval '1 day')
331 `, org.ID); err != nil {
332 t.Fatalf("insert invitations: %v", err)
333 }
334 pending, err := billing.CountPendingOrgInvitations(ctx, deps, org.ID)
335 if err != nil {
336 t.Fatalf("CountPendingOrgInvitations: %v", err)
337 }
338 if pending != 1 {
339 t.Fatalf("pending invitations: got %d, want 1", pending)
340 }
341 }
342
343 func TestStripeLookupsAndInvoiceSnapshot(t *testing.T) {
344 _, deps, org := setup(t)
345 ctx := context.Background()
346
347 start := time.Now().UTC().Truncate(time.Second)
348 if _, err := billing.SetStripeCustomer(ctx, deps, org.ID, "cus_lookup"); err != nil {
349 t.Fatalf("SetStripeCustomer: %v", err)
350 }
351 if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
352 OrgID: org.ID,
353 Plan: billing.PlanTeam,
354 Status: billing.SubscriptionStatusActive,
355 StripeSubscriptionID: "sub_lookup",
356 StripeSubscriptionItemID: "si_lookup",
357 CurrentPeriodStart: start,
358 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
359 LastWebhookEventID: "evt_lookup",
360 }); err != nil {
361 t.Fatalf("ApplySubscriptionSnapshot: %v", err)
362 }
363
364 byCustomer, err := billing.GetOrgBillingStateByStripeCustomer(ctx, deps, "cus_lookup")
365 if err != nil {
366 t.Fatalf("GetOrgBillingStateByStripeCustomer: %v", err)
367 }
368 if byCustomer.OrgID != org.ID {
369 t.Fatalf("customer lookup org_id: got %d, want %d", byCustomer.OrgID, org.ID)
370 }
371 bySubscription, err := billing.GetOrgBillingStateByStripeSubscription(ctx, deps, "sub_lookup")
372 if err != nil {
373 t.Fatalf("GetOrgBillingStateByStripeSubscription: %v", err)
374 }
375 if bySubscription.OrgID != org.ID {
376 t.Fatalf("subscription lookup org_id: got %d, want %d", bySubscription.OrgID, org.ID)
377 }
378
379 invoice, err := billing.UpsertInvoice(ctx, deps, billing.InvoiceSnapshot{
380 OrgID: org.ID,
381 StripeInvoiceID: "in_lookup",
382 StripeCustomerID: "cus_lookup",
383 StripeSubscriptionID: "sub_lookup",
384 Status: billing.InvoiceStatusPaid,
385 Number: "SHI-0001",
386 Currency: "USD",
387 AmountDueCents: 1200,
388 AmountPaidCents: 1200,
389 AmountRemainingCents: 0,
390 HostedInvoiceURL: "https://invoice.stripe.test/i",
391 InvoicePDFURL: "https://invoice.stripe.test/i.pdf",
392 PeriodStart: start,
393 PeriodEnd: start.Add(30 * 24 * time.Hour),
394 PaidAt: start.Add(time.Minute),
395 })
396 if err != nil {
397 t.Fatalf("UpsertInvoice: %v", err)
398 }
399 if invoice.StripeInvoiceID != "in_lookup" || invoice.Status != billing.InvoiceStatusPaid || invoice.Currency != "usd" {
400 t.Fatalf("unexpected invoice: %+v", invoice)
401 }
402 }
403
404 func assertState(t *testing.T, state billing.State, plan billing.Plan, status billing.SubscriptionStatus) {
405 t.Helper()
406 if state.Plan != plan || state.SubscriptionStatus != status {
407 t.Fatalf("state: want plan=%s status=%s, got plan=%s status=%s", plan, status, state.Plan, state.SubscriptionStatus)
408 }
409 }
410
411 func assertOrgPlan(t *testing.T, pool *pgxpool.Pool, orgID int64, want orgsdb.OrgPlan) {
412 t.Helper()
413 row, err := orgsdb.New().GetOrgByID(context.Background(), pool, orgID)
414 if err != nil {
415 t.Fatalf("GetOrgByID: %v", err)
416 }
417 if row.Plan != want {
418 t.Fatalf("org plan: want %s, got %s", want, row.Plan)
419 }
420 }
421