tenseleyflow/shithub / 25f49fa

Browse files

Add Stripe billing adapter

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
25f49fa40e630f05fc2c441f3001d5558a1a4aea
Parents
8f753ab
Tree
366a162

7 changed files

StatusFile+-
M internal/billing/billing.go 161 10
M internal/billing/billing_test.go 76 0
M internal/billing/queries/billing.sql 15 0
M internal/billing/sqlc/billing.sql.go 83 0
M internal/billing/sqlc/querier.go 5 0
A internal/billing/stripebilling/client.go 244 0
A internal/billing/stripebilling/client_test.go 52 0
internal/billing/billing.gomodified
@@ -2,7 +2,7 @@
22
 
33
 // Package billing owns local paid-organization state. It stores Stripe
44
 // identifiers and derived subscription state, but it does not call
5
-// Stripe directly; webhook/API integration lands in SP03.
5
+// Stripe directly; Stripe API details stay in the SP03 adapter layer.
66
 package billing
77
 
88
 import (
@@ -27,6 +27,7 @@ type Deps struct {
2727
 type (
2828
 	Plan               = billingdb.OrgPlan
2929
 	SubscriptionStatus = billingdb.BillingSubscriptionStatus
30
+	InvoiceStatus      = billingdb.BillingInvoiceStatus
3031
 	State              = billingdb.OrgBillingState
3132
 )
3233
 
@@ -43,18 +44,27 @@ const (
4344
 	SubscriptionStatusCanceled   = billingdb.BillingSubscriptionStatusCanceled
4445
 	SubscriptionStatusUnpaid     = billingdb.BillingSubscriptionStatusUnpaid
4546
 	SubscriptionStatusPaused     = billingdb.BillingSubscriptionStatusPaused
47
+
48
+	InvoiceStatusDraft         = billingdb.BillingInvoiceStatusDraft
49
+	InvoiceStatusOpen          = billingdb.BillingInvoiceStatusOpen
50
+	InvoiceStatusPaid          = billingdb.BillingInvoiceStatusPaid
51
+	InvoiceStatusVoid          = billingdb.BillingInvoiceStatusVoid
52
+	InvoiceStatusUncollectible = billingdb.BillingInvoiceStatusUncollectible
4653
 )
4754
 
4855
 var (
49
-	ErrPoolRequired     = errors.New("billing: pool is required")
50
-	ErrOrgIDRequired    = errors.New("billing: org id is required")
51
-	ErrStripeCustomerID = errors.New("billing: stripe customer id is required")
52
-	ErrInvalidPlan      = errors.New("billing: invalid plan")
53
-	ErrInvalidStatus    = errors.New("billing: invalid subscription status")
54
-	ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative")
55
-	ErrWebhookEventID   = errors.New("billing: webhook event id is required")
56
-	ErrWebhookEventType = errors.New("billing: webhook event type is required")
57
-	ErrWebhookPayload   = errors.New("billing: webhook payload must be a JSON object")
56
+	ErrPoolRequired         = errors.New("billing: pool is required")
57
+	ErrOrgIDRequired        = errors.New("billing: org id is required")
58
+	ErrStripeCustomerID     = errors.New("billing: stripe customer id is required")
59
+	ErrStripeSubscriptionID = errors.New("billing: stripe subscription id is required")
60
+	ErrStripeInvoiceID      = errors.New("billing: stripe invoice id is required")
61
+	ErrInvalidPlan          = errors.New("billing: invalid plan")
62
+	ErrInvalidStatus        = errors.New("billing: invalid subscription status")
63
+	ErrInvalidInvoiceStatus = errors.New("billing: invalid invoice status")
64
+	ErrInvalidSeatCount     = errors.New("billing: seat counts cannot be negative")
65
+	ErrWebhookEventID       = errors.New("billing: webhook event id is required")
66
+	ErrWebhookEventType     = errors.New("billing: webhook event type is required")
67
+	ErrWebhookPayload       = errors.New("billing: webhook payload must be a JSON object")
5868
 )
5969
 
6070
 // SubscriptionSnapshot is the local projection of a provider
@@ -88,6 +98,26 @@ type WebhookEvent struct {
8898
 	Payload         []byte
8999
 }
90100
 
101
+type InvoiceSnapshot struct {
102
+	OrgID                int64
103
+	StripeInvoiceID      string
104
+	StripeCustomerID     string
105
+	StripeSubscriptionID string
106
+	Status               InvoiceStatus
107
+	Number               string
108
+	Currency             string
109
+	AmountDueCents       int64
110
+	AmountPaidCents      int64
111
+	AmountRemainingCents int64
112
+	HostedInvoiceURL     string
113
+	InvoicePDFURL        string
114
+	PeriodStart          time.Time
115
+	PeriodEnd            time.Time
116
+	DueAt                time.Time
117
+	PaidAt               time.Time
118
+	VoidedAt             time.Time
119
+}
120
+
91121
 func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, error) {
92122
 	if err := validateDeps(deps); err != nil {
93123
 		return State{}, err
@@ -98,6 +128,28 @@ func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, err
98128
 	return billingdb.New().GetOrgBillingState(ctx, deps.Pool, orgID)
99129
 }
100130
 
131
+func GetOrgBillingStateByStripeCustomer(ctx context.Context, deps Deps, customerID string) (State, error) {
132
+	if err := validateDeps(deps); err != nil {
133
+		return State{}, err
134
+	}
135
+	customerID = strings.TrimSpace(customerID)
136
+	if customerID == "" {
137
+		return State{}, ErrStripeCustomerID
138
+	}
139
+	return billingdb.New().GetOrgBillingStateByStripeCustomer(ctx, deps.Pool, pgText(customerID))
140
+}
141
+
142
+func GetOrgBillingStateByStripeSubscription(ctx context.Context, deps Deps, subscriptionID string) (State, error) {
143
+	if err := validateDeps(deps); err != nil {
144
+		return State{}, err
145
+	}
146
+	subscriptionID = strings.TrimSpace(subscriptionID)
147
+	if subscriptionID == "" {
148
+		return State{}, ErrStripeSubscriptionID
149
+	}
150
+	return billingdb.New().GetOrgBillingStateByStripeSubscription(ctx, deps.Pool, pgText(subscriptionID))
151
+}
152
+
101153
 func SetStripeCustomer(ctx context.Context, deps Deps, orgID int64, customerID string) (State, error) {
102154
 	if err := validateDeps(deps); err != nil {
103155
 		return State{}, err
@@ -178,6 +230,78 @@ func RecordWebhookEvent(ctx context.Context, deps Deps, event WebhookEvent) (bil
178230
 	return row, true, nil
179231
 }
180232
 
233
+func MarkWebhookEventProcessed(ctx context.Context, deps Deps, providerEventID string) (billingdb.BillingWebhookEvent, error) {
234
+	if err := validateDeps(deps); err != nil {
235
+		return billingdb.BillingWebhookEvent{}, err
236
+	}
237
+	providerEventID = strings.TrimSpace(providerEventID)
238
+	if providerEventID == "" {
239
+		return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
240
+	}
241
+	return billingdb.New().MarkWebhookEventProcessed(ctx, deps.Pool, providerEventID)
242
+}
243
+
244
+func MarkWebhookEventFailed(ctx context.Context, deps Deps, providerEventID, processError string) (billingdb.BillingWebhookEvent, error) {
245
+	if err := validateDeps(deps); err != nil {
246
+		return billingdb.BillingWebhookEvent{}, err
247
+	}
248
+	providerEventID = strings.TrimSpace(providerEventID)
249
+	if providerEventID == "" {
250
+		return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
251
+	}
252
+	processError = strings.TrimSpace(processError)
253
+	if len(processError) > 2000 {
254
+		processError = processError[:2000]
255
+	}
256
+	return billingdb.New().MarkWebhookEventFailed(ctx, deps.Pool, billingdb.MarkWebhookEventFailedParams{
257
+		ProviderEventID: providerEventID,
258
+		ProcessError:    processError,
259
+	})
260
+}
261
+
262
+func UpsertInvoice(ctx context.Context, deps Deps, snap InvoiceSnapshot) (billingdb.BillingInvoice, error) {
263
+	if err := validateDeps(deps); err != nil {
264
+		return billingdb.BillingInvoice{}, err
265
+	}
266
+	if snap.OrgID == 0 {
267
+		return billingdb.BillingInvoice{}, ErrOrgIDRequired
268
+	}
269
+	snap.StripeInvoiceID = strings.TrimSpace(snap.StripeInvoiceID)
270
+	if snap.StripeInvoiceID == "" {
271
+		return billingdb.BillingInvoice{}, ErrStripeInvoiceID
272
+	}
273
+	snap.StripeCustomerID = strings.TrimSpace(snap.StripeCustomerID)
274
+	if snap.StripeCustomerID == "" {
275
+		return billingdb.BillingInvoice{}, ErrStripeCustomerID
276
+	}
277
+	if !validInvoiceStatus(snap.Status) {
278
+		return billingdb.BillingInvoice{}, fmt.Errorf("%w: %q", ErrInvalidInvoiceStatus, snap.Status)
279
+	}
280
+	row, err := billingdb.New().UpsertInvoice(ctx, deps.Pool, billingdb.UpsertInvoiceParams{
281
+		OrgID:                snap.OrgID,
282
+		StripeInvoiceID:      snap.StripeInvoiceID,
283
+		StripeCustomerID:     snap.StripeCustomerID,
284
+		StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
285
+		Status:               snap.Status,
286
+		Number:               strings.TrimSpace(snap.Number),
287
+		Currency:             strings.ToLower(strings.TrimSpace(snap.Currency)),
288
+		AmountDueCents:       snap.AmountDueCents,
289
+		AmountPaidCents:      snap.AmountPaidCents,
290
+		AmountRemainingCents: snap.AmountRemainingCents,
291
+		HostedInvoiceUrl:     strings.TrimSpace(snap.HostedInvoiceURL),
292
+		InvoicePdfUrl:        strings.TrimSpace(snap.InvoicePDFURL),
293
+		PeriodStart:          pgTime(snap.PeriodStart),
294
+		PeriodEnd:            pgTime(snap.PeriodEnd),
295
+		DueAt:                pgTime(snap.DueAt),
296
+		PaidAt:               pgTime(snap.PaidAt),
297
+		VoidedAt:             pgTime(snap.VoidedAt),
298
+	})
299
+	if err != nil {
300
+		return billingdb.BillingInvoice{}, err
301
+	}
302
+	return row, nil
303
+}
304
+
181305
 func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billingdb.BillingSeatSnapshot, error) {
182306
 	if err := validateDeps(deps); err != nil {
183307
 		return billingdb.BillingSeatSnapshot{}, err
@@ -205,6 +329,20 @@ func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billin
205329
 	return billingdb.BillingSeatSnapshot(row), nil
206330
 }
207331
 
332
+func CountBillableOrgMembers(ctx context.Context, deps Deps, orgID int64) (int, error) {
333
+	if err := validateDeps(deps); err != nil {
334
+		return 0, err
335
+	}
336
+	if orgID == 0 {
337
+		return 0, ErrOrgIDRequired
338
+	}
339
+	n, err := billingdb.New().CountBillableOrgMembers(ctx, deps.Pool, orgID)
340
+	if err != nil {
341
+		return 0, err
342
+	}
343
+	return int(n), nil
344
+}
345
+
208346
 func MarkPastDue(ctx context.Context, deps Deps, orgID int64, graceUntil time.Time, lastWebhookEventID string) (State, error) {
209347
 	if err := validateDeps(deps); err != nil {
210348
 		return State{}, err
@@ -282,6 +420,19 @@ func validStatus(status SubscriptionStatus) bool {
282420
 	}
283421
 }
284422
 
423
+func validInvoiceStatus(status InvoiceStatus) bool {
424
+	switch status {
425
+	case InvoiceStatusDraft,
426
+		InvoiceStatusOpen,
427
+		InvoiceStatusPaid,
428
+		InvoiceStatusVoid,
429
+		InvoiceStatusUncollectible:
430
+		return true
431
+	default:
432
+		return false
433
+	}
434
+}
435
+
285436
 func pgText(s string) pgtype.Text {
286437
 	s = strings.TrimSpace(s)
287438
 	return pgtype.Text{String: s, Valid: s != ""}
internal/billing/billing_test.gomodified
@@ -162,6 +162,13 @@ func TestRecordWebhookEventIsIdempotent(t *testing.T) {
162162
 	if created {
163163
 		t.Fatalf("duplicate receipt should not be created")
164164
 	}
165
+
166
+	if _, err := billing.MarkWebhookEventProcessed(ctx, deps, event.ProviderEventID); err != nil {
167
+		t.Fatalf("MarkWebhookEventProcessed: %v", err)
168
+	}
169
+	if _, err := billing.MarkWebhookEventFailed(ctx, deps, event.ProviderEventID, "late duplicate"); err != nil {
170
+		t.Fatalf("MarkWebhookEventFailed: %v", err)
171
+	}
165172
 }
166173
 
167174
 func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) {
@@ -187,6 +194,75 @@ func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) {
187194
 	if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid {
188195
 		t.Fatalf("state did not record seat snapshot: %+v", state)
189196
 	}
197
+
198
+	count, err := billing.CountBillableOrgMembers(ctx, deps, org.ID)
199
+	if err != nil {
200
+		t.Fatalf("CountBillableOrgMembers: %v", err)
201
+	}
202
+	if count != 1 {
203
+		t.Fatalf("billable members: got %d, want 1", count)
204
+	}
205
+}
206
+
207
+func TestStripeLookupsAndInvoiceSnapshot(t *testing.T) {
208
+	_, deps, org := setup(t)
209
+	ctx := context.Background()
210
+
211
+	start := time.Now().UTC().Truncate(time.Second)
212
+	if _, err := billing.SetStripeCustomer(ctx, deps, org.ID, "cus_lookup"); err != nil {
213
+		t.Fatalf("SetStripeCustomer: %v", err)
214
+	}
215
+	if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{
216
+		OrgID:                    org.ID,
217
+		Plan:                     billing.PlanTeam,
218
+		Status:                   billing.SubscriptionStatusActive,
219
+		StripeSubscriptionID:     "sub_lookup",
220
+		StripeSubscriptionItemID: "si_lookup",
221
+		CurrentPeriodStart:       start,
222
+		CurrentPeriodEnd:         start.Add(30 * 24 * time.Hour),
223
+		LastWebhookEventID:       "evt_lookup",
224
+	}); err != nil {
225
+		t.Fatalf("ApplySubscriptionSnapshot: %v", err)
226
+	}
227
+
228
+	byCustomer, err := billing.GetOrgBillingStateByStripeCustomer(ctx, deps, "cus_lookup")
229
+	if err != nil {
230
+		t.Fatalf("GetOrgBillingStateByStripeCustomer: %v", err)
231
+	}
232
+	if byCustomer.OrgID != org.ID {
233
+		t.Fatalf("customer lookup org_id: got %d, want %d", byCustomer.OrgID, org.ID)
234
+	}
235
+	bySubscription, err := billing.GetOrgBillingStateByStripeSubscription(ctx, deps, "sub_lookup")
236
+	if err != nil {
237
+		t.Fatalf("GetOrgBillingStateByStripeSubscription: %v", err)
238
+	}
239
+	if bySubscription.OrgID != org.ID {
240
+		t.Fatalf("subscription lookup org_id: got %d, want %d", bySubscription.OrgID, org.ID)
241
+	}
242
+
243
+	invoice, err := billing.UpsertInvoice(ctx, deps, billing.InvoiceSnapshot{
244
+		OrgID:                org.ID,
245
+		StripeInvoiceID:      "in_lookup",
246
+		StripeCustomerID:     "cus_lookup",
247
+		StripeSubscriptionID: "sub_lookup",
248
+		Status:               billing.InvoiceStatusPaid,
249
+		Number:               "SHI-0001",
250
+		Currency:             "USD",
251
+		AmountDueCents:       1200,
252
+		AmountPaidCents:      1200,
253
+		AmountRemainingCents: 0,
254
+		HostedInvoiceURL:     "https://invoice.stripe.test/i",
255
+		InvoicePDFURL:        "https://invoice.stripe.test/i.pdf",
256
+		PeriodStart:          start,
257
+		PeriodEnd:            start.Add(30 * 24 * time.Hour),
258
+		PaidAt:               start.Add(time.Minute),
259
+	})
260
+	if err != nil {
261
+		t.Fatalf("UpsertInvoice: %v", err)
262
+	}
263
+	if invoice.StripeInvoiceID != "in_lookup" || invoice.Status != billing.InvoiceStatusPaid || invoice.Currency != "usd" {
264
+		t.Fatalf("unexpected invoice: %+v", invoice)
265
+	}
190266
 }
191267
 
192268
 func assertState(t *testing.T, state billing.State, plan billing.Plan, status billing.SubscriptionStatus) {
internal/billing/queries/billing.sqlmodified
@@ -5,6 +5,16 @@
55
 -- name: GetOrgBillingState :one
66
 SELECT * FROM org_billing_states WHERE org_id = $1;
77
 
8
+-- name: GetOrgBillingStateByStripeCustomer :one
9
+SELECT * FROM org_billing_states
10
+WHERE provider = 'stripe'
11
+  AND stripe_customer_id = $1;
12
+
13
+-- name: GetOrgBillingStateByStripeSubscription :one
14
+SELECT * FROM org_billing_states
15
+WHERE provider = 'stripe'
16
+  AND stripe_subscription_id = $1;
17
+
818
 -- name: SetStripeCustomer :one
919
 INSERT INTO org_billing_states (org_id, provider, stripe_customer_id)
1020
 VALUES ($1, 'stripe', $2)
@@ -184,6 +194,11 @@ WHERE org_id = $1
184194
 ORDER BY captured_at DESC, id DESC
185195
 LIMIT $2;
186196
 
197
+-- name: CountBillableOrgMembers :one
198
+SELECT count(*)::integer
199
+FROM org_members
200
+WHERE org_id = $1;
201
+
187202
 -- ─── billing_invoices ──────────────────────────────────────────────
188203
 
189204
 -- name: UpsertInvoice :one
internal/billing/sqlc/billing.sql.gomodified
@@ -242,6 +242,19 @@ func (q *Queries) ClearBillingLock(ctx context.Context, db DBTX, orgID int64) (C
242242
 	return i, err
243243
 }
244244
 
245
+const countBillableOrgMembers = `-- name: CountBillableOrgMembers :one
246
+SELECT count(*)::integer
247
+FROM org_members
248
+WHERE org_id = $1
249
+`
250
+
251
+func (q *Queries) CountBillableOrgMembers(ctx context.Context, db DBTX, orgID int64) (int32, error) {
252
+	row := db.QueryRow(ctx, countBillableOrgMembers, orgID)
253
+	var column_1 int32
254
+	err := row.Scan(&column_1)
255
+	return column_1, err
256
+}
257
+
245258
 const createSeatSnapshot = `-- name: CreateSeatSnapshot :one
246259
 
247260
 WITH snapshot AS (
@@ -404,6 +417,76 @@ func (q *Queries) GetOrgBillingState(ctx context.Context, db DBTX, orgID int64)
404417
 	return i, err
405418
 }
406419
 
420
+const getOrgBillingStateByStripeCustomer = `-- name: GetOrgBillingStateByStripeCustomer :one
421
+SELECT org_id, provider, stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id, plan, subscription_status, billable_seats, seat_snapshot_at, current_period_start, current_period_end, cancel_at_period_end, trial_end, past_due_at, canceled_at, locked_at, lock_reason, grace_until, last_webhook_event_id, created_at, updated_at FROM org_billing_states
422
+WHERE provider = 'stripe'
423
+  AND stripe_customer_id = $1
424
+`
425
+
426
+func (q *Queries) GetOrgBillingStateByStripeCustomer(ctx context.Context, db DBTX, stripeCustomerID pgtype.Text) (OrgBillingState, error) {
427
+	row := db.QueryRow(ctx, getOrgBillingStateByStripeCustomer, stripeCustomerID)
428
+	var i OrgBillingState
429
+	err := row.Scan(
430
+		&i.OrgID,
431
+		&i.Provider,
432
+		&i.StripeCustomerID,
433
+		&i.StripeSubscriptionID,
434
+		&i.StripeSubscriptionItemID,
435
+		&i.Plan,
436
+		&i.SubscriptionStatus,
437
+		&i.BillableSeats,
438
+		&i.SeatSnapshotAt,
439
+		&i.CurrentPeriodStart,
440
+		&i.CurrentPeriodEnd,
441
+		&i.CancelAtPeriodEnd,
442
+		&i.TrialEnd,
443
+		&i.PastDueAt,
444
+		&i.CanceledAt,
445
+		&i.LockedAt,
446
+		&i.LockReason,
447
+		&i.GraceUntil,
448
+		&i.LastWebhookEventID,
449
+		&i.CreatedAt,
450
+		&i.UpdatedAt,
451
+	)
452
+	return i, err
453
+}
454
+
455
+const getOrgBillingStateByStripeSubscription = `-- name: GetOrgBillingStateByStripeSubscription :one
456
+SELECT org_id, provider, stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id, plan, subscription_status, billable_seats, seat_snapshot_at, current_period_start, current_period_end, cancel_at_period_end, trial_end, past_due_at, canceled_at, locked_at, lock_reason, grace_until, last_webhook_event_id, created_at, updated_at FROM org_billing_states
457
+WHERE provider = 'stripe'
458
+  AND stripe_subscription_id = $1
459
+`
460
+
461
+func (q *Queries) GetOrgBillingStateByStripeSubscription(ctx context.Context, db DBTX, stripeSubscriptionID pgtype.Text) (OrgBillingState, error) {
462
+	row := db.QueryRow(ctx, getOrgBillingStateByStripeSubscription, stripeSubscriptionID)
463
+	var i OrgBillingState
464
+	err := row.Scan(
465
+		&i.OrgID,
466
+		&i.Provider,
467
+		&i.StripeCustomerID,
468
+		&i.StripeSubscriptionID,
469
+		&i.StripeSubscriptionItemID,
470
+		&i.Plan,
471
+		&i.SubscriptionStatus,
472
+		&i.BillableSeats,
473
+		&i.SeatSnapshotAt,
474
+		&i.CurrentPeriodStart,
475
+		&i.CurrentPeriodEnd,
476
+		&i.CancelAtPeriodEnd,
477
+		&i.TrialEnd,
478
+		&i.PastDueAt,
479
+		&i.CanceledAt,
480
+		&i.LockedAt,
481
+		&i.LockReason,
482
+		&i.GraceUntil,
483
+		&i.LastWebhookEventID,
484
+		&i.CreatedAt,
485
+		&i.UpdatedAt,
486
+	)
487
+	return i, err
488
+}
489
+
407490
 const listInvoicesForOrg = `-- name: ListInvoicesForOrg :many
408491
 SELECT id, org_id, provider, stripe_invoice_id, stripe_customer_id, stripe_subscription_id, status, number, currency, amount_due_cents, amount_paid_cents, amount_remaining_cents, hosted_invoice_url, invoice_pdf_url, period_start, period_end, due_at, paid_at, voided_at, created_at, updated_at FROM billing_invoices
409492
 WHERE org_id = $1
internal/billing/sqlc/querier.gomodified
@@ -6,11 +6,14 @@ package billingdb
66
 
77
 import (
88
 	"context"
9
+
10
+	"github.com/jackc/pgx/v5/pgtype"
911
 )
1012
 
1113
 type Querier interface {
1214
 	ApplySubscriptionSnapshot(ctx context.Context, db DBTX, arg ApplySubscriptionSnapshotParams) (ApplySubscriptionSnapshotRow, error)
1315
 	ClearBillingLock(ctx context.Context, db DBTX, orgID int64) (ClearBillingLockRow, error)
16
+	CountBillableOrgMembers(ctx context.Context, db DBTX, orgID int64) (int32, error)
1417
 	// ─── billing_seat_snapshots ────────────────────────────────────────
1518
 	CreateSeatSnapshot(ctx context.Context, db DBTX, arg CreateSeatSnapshotParams) (CreateSeatSnapshotRow, error)
1619
 	// ─── billing_webhook_events ────────────────────────────────────────
@@ -18,6 +21,8 @@ type Querier interface {
1821
 	// SPDX-License-Identifier: AGPL-3.0-or-later
1922
 	// ─── org_billing_states ────────────────────────────────────────────
2023
 	GetOrgBillingState(ctx context.Context, db DBTX, orgID int64) (OrgBillingState, error)
24
+	GetOrgBillingStateByStripeCustomer(ctx context.Context, db DBTX, stripeCustomerID pgtype.Text) (OrgBillingState, error)
25
+	GetOrgBillingStateByStripeSubscription(ctx context.Context, db DBTX, stripeSubscriptionID pgtype.Text) (OrgBillingState, error)
2126
 	ListInvoicesForOrg(ctx context.Context, db DBTX, arg ListInvoicesForOrgParams) ([]BillingInvoice, error)
2227
 	ListSeatSnapshotsForOrg(ctx context.Context, db DBTX, arg ListSeatSnapshotsForOrgParams) ([]BillingSeatSnapshot, error)
2328
 	MarkCanceled(ctx context.Context, db DBTX, arg MarkCanceledParams) (MarkCanceledRow, error)
internal/billing/stripebilling/client.goadded
@@ -0,0 +1,244 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package stripebilling contains the Stripe-specific edge of the billing
4
+// system. Local subscription state stays in internal/billing; this package
5
+// owns hosted Checkout, Billing Portal, seat quantity updates, and webhook
6
+// signature verification.
7
+package stripebilling
8
+
9
+import (
10
+	"context"
11
+	"errors"
12
+	"fmt"
13
+	"strconv"
14
+	"strings"
15
+
16
+	stripeapi "github.com/stripe/stripe-go/v85"
17
+	"github.com/stripe/stripe-go/v85/webhook"
18
+)
19
+
20
+const (
21
+	MetadataOrgID   = "shithub_org_id"
22
+	MetadataOrgSlug = "shithub_org_slug"
23
+)
24
+
25
+var (
26
+	ErrSecretKeyRequired     = errors.New("stripe billing: secret key is required")
27
+	ErrWebhookSecretRequired = errors.New("stripe billing: webhook secret is required")
28
+	ErrTeamPriceRequired     = errors.New("stripe billing: team price id is required")
29
+	ErrCustomerIDRequired    = errors.New("stripe billing: customer id is required")
30
+	ErrSubscriptionItemID    = errors.New("stripe billing: subscription item id is required")
31
+	ErrURLRequired           = errors.New("stripe billing: redirect url is required")
32
+)
33
+
34
+type Config struct {
35
+	SecretKey     string
36
+	WebhookSecret string
37
+	TeamPriceID   string
38
+	AutomaticTax  bool
39
+}
40
+
41
+type Remote interface {
42
+	CreateCustomer(context.Context, CustomerInput) (Customer, error)
43
+	CreateCheckoutSession(context.Context, CheckoutInput) (CheckoutSession, error)
44
+	CreatePortalSession(context.Context, PortalInput) (PortalSession, error)
45
+	UpdateSubscriptionItemQuantity(context.Context, SeatQuantityInput) error
46
+	VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error)
47
+}
48
+
49
+type Client struct {
50
+	stripe        *stripeapi.Client
51
+	webhookSecret string
52
+	teamPriceID   string
53
+	automaticTax  bool
54
+}
55
+
56
+type CustomerInput struct {
57
+	OrgID   int64
58
+	OrgSlug string
59
+	OrgName string
60
+	Email   string
61
+}
62
+
63
+type Customer struct {
64
+	ID string
65
+}
66
+
67
+type CheckoutInput struct {
68
+	OrgID      int64
69
+	OrgSlug    string
70
+	CustomerID string
71
+	SeatCount  int64
72
+	SuccessURL string
73
+	CancelURL  string
74
+}
75
+
76
+type CheckoutSession struct {
77
+	ID  string
78
+	URL string
79
+}
80
+
81
+type PortalInput struct {
82
+	CustomerID string
83
+	ReturnURL  string
84
+}
85
+
86
+type PortalSession struct {
87
+	ID  string
88
+	URL string
89
+}
90
+
91
+type SeatQuantityInput struct {
92
+	OrgID              int64
93
+	SubscriptionItemID string
94
+	Quantity           int64
95
+}
96
+
97
+func New(cfg Config) (*Client, error) {
98
+	cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
99
+	cfg.WebhookSecret = strings.TrimSpace(cfg.WebhookSecret)
100
+	cfg.TeamPriceID = strings.TrimSpace(cfg.TeamPriceID)
101
+	if cfg.SecretKey == "" {
102
+		return nil, ErrSecretKeyRequired
103
+	}
104
+	if cfg.WebhookSecret == "" {
105
+		return nil, ErrWebhookSecretRequired
106
+	}
107
+	if cfg.TeamPriceID == "" {
108
+		return nil, ErrTeamPriceRequired
109
+	}
110
+	return &Client{
111
+		stripe:        stripeapi.NewClient(cfg.SecretKey),
112
+		webhookSecret: cfg.WebhookSecret,
113
+		teamPriceID:   cfg.TeamPriceID,
114
+		automaticTax:  cfg.AutomaticTax,
115
+	}, nil
116
+}
117
+
118
+func (c *Client) CreateCustomer(ctx context.Context, in CustomerInput) (Customer, error) {
119
+	name := strings.TrimSpace(in.OrgName)
120
+	if name == "" {
121
+		name = strings.TrimSpace(in.OrgSlug)
122
+	}
123
+	params := &stripeapi.CustomerCreateParams{
124
+		Name:        stripeapi.String(name),
125
+		Description: stripeapi.String(fmt.Sprintf("shithub organization %s", strings.TrimSpace(in.OrgSlug))),
126
+		Metadata:    orgMetadata(in.OrgID, in.OrgSlug),
127
+	}
128
+	if email := strings.TrimSpace(in.Email); email != "" {
129
+		params.Email = stripeapi.String(email)
130
+	}
131
+	params.SetIdempotencyKey(idempotencyKey("customer", in.OrgID, "v1"))
132
+	customer, err := c.stripe.V1Customers.Create(ctx, params)
133
+	if err != nil {
134
+		return Customer{}, err
135
+	}
136
+	return Customer{ID: customer.ID}, nil
137
+}
138
+
139
+func (c *Client) CreateCheckoutSession(ctx context.Context, in CheckoutInput) (CheckoutSession, error) {
140
+	in.CustomerID = strings.TrimSpace(in.CustomerID)
141
+	if in.CustomerID == "" {
142
+		return CheckoutSession{}, ErrCustomerIDRequired
143
+	}
144
+	in.SuccessURL = strings.TrimSpace(in.SuccessURL)
145
+	if in.SuccessURL == "" {
146
+		return CheckoutSession{}, fmt.Errorf("%w: success_url", ErrURLRequired)
147
+	}
148
+	in.CancelURL = strings.TrimSpace(in.CancelURL)
149
+	if in.CancelURL == "" {
150
+		return CheckoutSession{}, fmt.Errorf("%w: cancel_url", ErrURLRequired)
151
+	}
152
+	if in.SeatCount < 1 {
153
+		in.SeatCount = 1
154
+	}
155
+	metadata := orgMetadata(in.OrgID, in.OrgSlug)
156
+	mode := string(stripeapi.CheckoutSessionModeSubscription)
157
+	paymentMethodCollection := string(stripeapi.CheckoutSessionPaymentMethodCollectionAlways)
158
+	billingAddressCollection := string(stripeapi.CheckoutSessionBillingAddressCollectionAuto)
159
+	params := &stripeapi.CheckoutSessionCreateParams{
160
+		Mode:                     stripeapi.String(mode),
161
+		Customer:                 stripeapi.String(in.CustomerID),
162
+		ClientReferenceID:        stripeapi.String(strconv.FormatInt(in.OrgID, 10)),
163
+		SuccessURL:               stripeapi.String(in.SuccessURL),
164
+		CancelURL:                stripeapi.String(in.CancelURL),
165
+		PaymentMethodCollection:  stripeapi.String(paymentMethodCollection),
166
+		BillingAddressCollection: stripeapi.String(billingAddressCollection),
167
+		LineItems: []*stripeapi.CheckoutSessionCreateLineItemParams{{
168
+			Price:    stripeapi.String(c.teamPriceID),
169
+			Quantity: stripeapi.Int64(in.SeatCount),
170
+		}},
171
+		Metadata: metadata,
172
+		SubscriptionData: &stripeapi.CheckoutSessionCreateSubscriptionDataParams{
173
+			Metadata: metadata,
174
+		},
175
+	}
176
+	if c.automaticTax {
177
+		params.AutomaticTax = &stripeapi.CheckoutSessionCreateAutomaticTaxParams{
178
+			Enabled: stripeapi.Bool(true),
179
+		}
180
+	}
181
+	params.SetIdempotencyKey(idempotencyKey("checkout", in.OrgID, "team", strconv.FormatInt(in.SeatCount, 10)))
182
+	session, err := c.stripe.V1CheckoutSessions.Create(ctx, params)
183
+	if err != nil {
184
+		return CheckoutSession{}, err
185
+	}
186
+	return CheckoutSession{ID: session.ID, URL: session.URL}, nil
187
+}
188
+
189
+func (c *Client) CreatePortalSession(ctx context.Context, in PortalInput) (PortalSession, error) {
190
+	in.CustomerID = strings.TrimSpace(in.CustomerID)
191
+	if in.CustomerID == "" {
192
+		return PortalSession{}, ErrCustomerIDRequired
193
+	}
194
+	in.ReturnURL = strings.TrimSpace(in.ReturnURL)
195
+	if in.ReturnURL == "" {
196
+		return PortalSession{}, fmt.Errorf("%w: portal_return_url", ErrURLRequired)
197
+	}
198
+	params := &stripeapi.BillingPortalSessionCreateParams{
199
+		Customer:  stripeapi.String(in.CustomerID),
200
+		ReturnURL: stripeapi.String(in.ReturnURL),
201
+	}
202
+	session, err := c.stripe.V1BillingPortalSessions.Create(ctx, params)
203
+	if err != nil {
204
+		return PortalSession{}, err
205
+	}
206
+	return PortalSession{ID: session.ID, URL: session.URL}, nil
207
+}
208
+
209
+func (c *Client) UpdateSubscriptionItemQuantity(ctx context.Context, in SeatQuantityInput) error {
210
+	in.SubscriptionItemID = strings.TrimSpace(in.SubscriptionItemID)
211
+	if in.SubscriptionItemID == "" {
212
+		return ErrSubscriptionItemID
213
+	}
214
+	if in.Quantity < 1 {
215
+		in.Quantity = 1
216
+	}
217
+	params := &stripeapi.SubscriptionItemUpdateParams{
218
+		Quantity: stripeapi.Int64(in.Quantity),
219
+	}
220
+	params.SetIdempotencyKey(idempotencyKey("seat-sync", in.OrgID, in.SubscriptionItemID, strconv.FormatInt(in.Quantity, 10)))
221
+	_, err := c.stripe.V1SubscriptionItems.Update(ctx, in.SubscriptionItemID, params)
222
+	return err
223
+}
224
+
225
+func (c *Client) VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error) {
226
+	return webhook.ConstructEvent(payload, signatureHeader, c.webhookSecret)
227
+}
228
+
229
+func orgMetadata(orgID int64, orgSlug string) map[string]string {
230
+	return map[string]string{
231
+		MetadataOrgID:   strconv.FormatInt(orgID, 10),
232
+		MetadataOrgSlug: strings.TrimSpace(orgSlug),
233
+	}
234
+}
235
+
236
+func idempotencyKey(parts ...any) string {
237
+	var b strings.Builder
238
+	b.WriteString("shithub")
239
+	for _, part := range parts {
240
+		b.WriteByte(':')
241
+		b.WriteString(strings.NewReplacer(":", "_", " ", "_", "/", "_").Replace(fmt.Sprint(part)))
242
+	}
243
+	return b.String()
244
+}
internal/billing/stripebilling/client_test.goadded
@@ -0,0 +1,52 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package stripebilling
4
+
5
+import (
6
+	"errors"
7
+	"fmt"
8
+	"testing"
9
+
10
+	stripeapi "github.com/stripe/stripe-go/v85"
11
+	"github.com/stripe/stripe-go/v85/webhook"
12
+)
13
+
14
+func TestNewValidatesRequiredConfig(t *testing.T) {
15
+	t.Parallel()
16
+	if _, err := New(Config{}); !errors.Is(err, ErrSecretKeyRequired) {
17
+		t.Fatalf("New without secret key: got %v", err)
18
+	}
19
+	if _, err := New(Config{SecretKey: "sk_test_123"}); !errors.Is(err, ErrWebhookSecretRequired) {
20
+		t.Fatalf("New without webhook secret: got %v", err)
21
+	}
22
+	if _, err := New(Config{SecretKey: "sk_test_123", WebhookSecret: "whsec_123"}); !errors.Is(err, ErrTeamPriceRequired) {
23
+		t.Fatalf("New without price id: got %v", err)
24
+	}
25
+}
26
+
27
+func TestVerifyWebhookUsesSigningSecret(t *testing.T) {
28
+	t.Parallel()
29
+	client, err := New(Config{
30
+		SecretKey:     "sk_test_123",
31
+		WebhookSecret: "whsec_test",
32
+		TeamPriceID:   "price_123",
33
+	})
34
+	if err != nil {
35
+		t.Fatalf("New: %v", err)
36
+	}
37
+	payload := []byte(fmt.Sprintf(`{"id":"evt_test","object":"event","api_version":%q,"type":"customer.subscription.updated","data":{"object":{"id":"sub_test","object":"subscription"}}}`, stripeapi.APIVersion))
38
+	signed := webhook.GenerateTestSignedPayload(&webhook.UnsignedPayload{
39
+		Payload: payload,
40
+		Secret:  "whsec_test",
41
+	})
42
+	event, err := client.VerifyWebhook(payload, signed.Header)
43
+	if err != nil {
44
+		t.Fatalf("VerifyWebhook: %v", err)
45
+	}
46
+	if event.ID != "evt_test" || event.Type != "customer.subscription.updated" {
47
+		t.Fatalf("unexpected event: id=%s type=%s", event.ID, event.Type)
48
+	}
49
+	if _, err := client.VerifyWebhook(payload, "t=1,v1=bad"); err == nil {
50
+		t.Fatalf("VerifyWebhook accepted bad signature")
51
+	}
52
+}