Add Stripe billing adapter
Authored by
mfwolffe <wolffemf@dukes.jmu.edu>
- SHA
25f49fa40e630f05fc2c441f3001d5558a1a4aea- Parents
-
8f753ab - Tree
366a162
25f49fa
25f49fa40e630f05fc2c441f3001d5558a1a4aea8f753ab
366a162| Status | File | + | - |
|---|---|---|---|
| M |
internal/billing/billing.go
|
161 | 10 |
| M |
internal/billing/billing_test.go
|
76 | 0 |
| M |
internal/billing/queries/billing.sql
|
15 | 0 |
| M |
internal/billing/sqlc/billing.sql.go
|
83 | 0 |
| M |
internal/billing/sqlc/querier.go
|
5 | 0 |
| A |
internal/billing/stripebilling/client.go
|
244 | 0 |
| A |
internal/billing/stripebilling/client_test.go
|
52 | 0 |
internal/billing/billing.gomodified@@ -2,7 +2,7 @@ | ||
| 2 | 2 | |
| 3 | 3 | // Package billing owns local paid-organization state. It stores Stripe |
| 4 | 4 | // identifiers and derived subscription state, but it does not call |
| 5 | -// Stripe directly; webhook/API integration lands in SP03. | |
| 5 | +// Stripe directly; Stripe API details stay in the SP03 adapter layer. | |
| 6 | 6 | package billing |
| 7 | 7 | |
| 8 | 8 | import ( |
@@ -27,6 +27,7 @@ type Deps struct { | ||
| 27 | 27 | type ( |
| 28 | 28 | Plan = billingdb.OrgPlan |
| 29 | 29 | SubscriptionStatus = billingdb.BillingSubscriptionStatus |
| 30 | + InvoiceStatus = billingdb.BillingInvoiceStatus | |
| 30 | 31 | State = billingdb.OrgBillingState |
| 31 | 32 | ) |
| 32 | 33 | |
@@ -43,18 +44,27 @@ const ( | ||
| 43 | 44 | SubscriptionStatusCanceled = billingdb.BillingSubscriptionStatusCanceled |
| 44 | 45 | SubscriptionStatusUnpaid = billingdb.BillingSubscriptionStatusUnpaid |
| 45 | 46 | SubscriptionStatusPaused = billingdb.BillingSubscriptionStatusPaused |
| 47 | + | |
| 48 | + InvoiceStatusDraft = billingdb.BillingInvoiceStatusDraft | |
| 49 | + InvoiceStatusOpen = billingdb.BillingInvoiceStatusOpen | |
| 50 | + InvoiceStatusPaid = billingdb.BillingInvoiceStatusPaid | |
| 51 | + InvoiceStatusVoid = billingdb.BillingInvoiceStatusVoid | |
| 52 | + InvoiceStatusUncollectible = billingdb.BillingInvoiceStatusUncollectible | |
| 46 | 53 | ) |
| 47 | 54 | |
| 48 | 55 | var ( |
| 49 | - ErrPoolRequired = errors.New("billing: pool is required") | |
| 50 | - ErrOrgIDRequired = errors.New("billing: org id is required") | |
| 51 | - ErrStripeCustomerID = errors.New("billing: stripe customer id is required") | |
| 52 | - ErrInvalidPlan = errors.New("billing: invalid plan") | |
| 53 | - ErrInvalidStatus = errors.New("billing: invalid subscription status") | |
| 54 | - ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative") | |
| 55 | - ErrWebhookEventID = errors.New("billing: webhook event id is required") | |
| 56 | - ErrWebhookEventType = errors.New("billing: webhook event type is required") | |
| 57 | - ErrWebhookPayload = errors.New("billing: webhook payload must be a JSON object") | |
| 56 | + ErrPoolRequired = errors.New("billing: pool is required") | |
| 57 | + ErrOrgIDRequired = errors.New("billing: org id is required") | |
| 58 | + ErrStripeCustomerID = errors.New("billing: stripe customer id is required") | |
| 59 | + ErrStripeSubscriptionID = errors.New("billing: stripe subscription id is required") | |
| 60 | + ErrStripeInvoiceID = errors.New("billing: stripe invoice id is required") | |
| 61 | + ErrInvalidPlan = errors.New("billing: invalid plan") | |
| 62 | + ErrInvalidStatus = errors.New("billing: invalid subscription status") | |
| 63 | + ErrInvalidInvoiceStatus = errors.New("billing: invalid invoice status") | |
| 64 | + ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative") | |
| 65 | + ErrWebhookEventID = errors.New("billing: webhook event id is required") | |
| 66 | + ErrWebhookEventType = errors.New("billing: webhook event type is required") | |
| 67 | + ErrWebhookPayload = errors.New("billing: webhook payload must be a JSON object") | |
| 58 | 68 | ) |
| 59 | 69 | |
| 60 | 70 | // SubscriptionSnapshot is the local projection of a provider |
@@ -88,6 +98,26 @@ type WebhookEvent struct { | ||
| 88 | 98 | Payload []byte |
| 89 | 99 | } |
| 90 | 100 | |
| 101 | +type InvoiceSnapshot struct { | |
| 102 | + OrgID int64 | |
| 103 | + StripeInvoiceID string | |
| 104 | + StripeCustomerID string | |
| 105 | + StripeSubscriptionID string | |
| 106 | + Status InvoiceStatus | |
| 107 | + Number string | |
| 108 | + Currency string | |
| 109 | + AmountDueCents int64 | |
| 110 | + AmountPaidCents int64 | |
| 111 | + AmountRemainingCents int64 | |
| 112 | + HostedInvoiceURL string | |
| 113 | + InvoicePDFURL string | |
| 114 | + PeriodStart time.Time | |
| 115 | + PeriodEnd time.Time | |
| 116 | + DueAt time.Time | |
| 117 | + PaidAt time.Time | |
| 118 | + VoidedAt time.Time | |
| 119 | +} | |
| 120 | + | |
| 91 | 121 | func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, error) { |
| 92 | 122 | if err := validateDeps(deps); err != nil { |
| 93 | 123 | return State{}, err |
@@ -98,6 +128,28 @@ func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, err | ||
| 98 | 128 | return billingdb.New().GetOrgBillingState(ctx, deps.Pool, orgID) |
| 99 | 129 | } |
| 100 | 130 | |
| 131 | +func GetOrgBillingStateByStripeCustomer(ctx context.Context, deps Deps, customerID string) (State, error) { | |
| 132 | + if err := validateDeps(deps); err != nil { | |
| 133 | + return State{}, err | |
| 134 | + } | |
| 135 | + customerID = strings.TrimSpace(customerID) | |
| 136 | + if customerID == "" { | |
| 137 | + return State{}, ErrStripeCustomerID | |
| 138 | + } | |
| 139 | + return billingdb.New().GetOrgBillingStateByStripeCustomer(ctx, deps.Pool, pgText(customerID)) | |
| 140 | +} | |
| 141 | + | |
| 142 | +func GetOrgBillingStateByStripeSubscription(ctx context.Context, deps Deps, subscriptionID string) (State, error) { | |
| 143 | + if err := validateDeps(deps); err != nil { | |
| 144 | + return State{}, err | |
| 145 | + } | |
| 146 | + subscriptionID = strings.TrimSpace(subscriptionID) | |
| 147 | + if subscriptionID == "" { | |
| 148 | + return State{}, ErrStripeSubscriptionID | |
| 149 | + } | |
| 150 | + return billingdb.New().GetOrgBillingStateByStripeSubscription(ctx, deps.Pool, pgText(subscriptionID)) | |
| 151 | +} | |
| 152 | + | |
| 101 | 153 | func SetStripeCustomer(ctx context.Context, deps Deps, orgID int64, customerID string) (State, error) { |
| 102 | 154 | if err := validateDeps(deps); err != nil { |
| 103 | 155 | return State{}, err |
@@ -178,6 +230,78 @@ func RecordWebhookEvent(ctx context.Context, deps Deps, event WebhookEvent) (bil | ||
| 178 | 230 | return row, true, nil |
| 179 | 231 | } |
| 180 | 232 | |
| 233 | +func MarkWebhookEventProcessed(ctx context.Context, deps Deps, providerEventID string) (billingdb.BillingWebhookEvent, error) { | |
| 234 | + if err := validateDeps(deps); err != nil { | |
| 235 | + return billingdb.BillingWebhookEvent{}, err | |
| 236 | + } | |
| 237 | + providerEventID = strings.TrimSpace(providerEventID) | |
| 238 | + if providerEventID == "" { | |
| 239 | + return billingdb.BillingWebhookEvent{}, ErrWebhookEventID | |
| 240 | + } | |
| 241 | + return billingdb.New().MarkWebhookEventProcessed(ctx, deps.Pool, providerEventID) | |
| 242 | +} | |
| 243 | + | |
| 244 | +func MarkWebhookEventFailed(ctx context.Context, deps Deps, providerEventID, processError string) (billingdb.BillingWebhookEvent, error) { | |
| 245 | + if err := validateDeps(deps); err != nil { | |
| 246 | + return billingdb.BillingWebhookEvent{}, err | |
| 247 | + } | |
| 248 | + providerEventID = strings.TrimSpace(providerEventID) | |
| 249 | + if providerEventID == "" { | |
| 250 | + return billingdb.BillingWebhookEvent{}, ErrWebhookEventID | |
| 251 | + } | |
| 252 | + processError = strings.TrimSpace(processError) | |
| 253 | + if len(processError) > 2000 { | |
| 254 | + processError = processError[:2000] | |
| 255 | + } | |
| 256 | + return billingdb.New().MarkWebhookEventFailed(ctx, deps.Pool, billingdb.MarkWebhookEventFailedParams{ | |
| 257 | + ProviderEventID: providerEventID, | |
| 258 | + ProcessError: processError, | |
| 259 | + }) | |
| 260 | +} | |
| 261 | + | |
| 262 | +func UpsertInvoice(ctx context.Context, deps Deps, snap InvoiceSnapshot) (billingdb.BillingInvoice, error) { | |
| 263 | + if err := validateDeps(deps); err != nil { | |
| 264 | + return billingdb.BillingInvoice{}, err | |
| 265 | + } | |
| 266 | + if snap.OrgID == 0 { | |
| 267 | + return billingdb.BillingInvoice{}, ErrOrgIDRequired | |
| 268 | + } | |
| 269 | + snap.StripeInvoiceID = strings.TrimSpace(snap.StripeInvoiceID) | |
| 270 | + if snap.StripeInvoiceID == "" { | |
| 271 | + return billingdb.BillingInvoice{}, ErrStripeInvoiceID | |
| 272 | + } | |
| 273 | + snap.StripeCustomerID = strings.TrimSpace(snap.StripeCustomerID) | |
| 274 | + if snap.StripeCustomerID == "" { | |
| 275 | + return billingdb.BillingInvoice{}, ErrStripeCustomerID | |
| 276 | + } | |
| 277 | + if !validInvoiceStatus(snap.Status) { | |
| 278 | + return billingdb.BillingInvoice{}, fmt.Errorf("%w: %q", ErrInvalidInvoiceStatus, snap.Status) | |
| 279 | + } | |
| 280 | + row, err := billingdb.New().UpsertInvoice(ctx, deps.Pool, billingdb.UpsertInvoiceParams{ | |
| 281 | + OrgID: snap.OrgID, | |
| 282 | + StripeInvoiceID: snap.StripeInvoiceID, | |
| 283 | + StripeCustomerID: snap.StripeCustomerID, | |
| 284 | + StripeSubscriptionID: pgText(snap.StripeSubscriptionID), | |
| 285 | + Status: snap.Status, | |
| 286 | + Number: strings.TrimSpace(snap.Number), | |
| 287 | + Currency: strings.ToLower(strings.TrimSpace(snap.Currency)), | |
| 288 | + AmountDueCents: snap.AmountDueCents, | |
| 289 | + AmountPaidCents: snap.AmountPaidCents, | |
| 290 | + AmountRemainingCents: snap.AmountRemainingCents, | |
| 291 | + HostedInvoiceUrl: strings.TrimSpace(snap.HostedInvoiceURL), | |
| 292 | + InvoicePdfUrl: strings.TrimSpace(snap.InvoicePDFURL), | |
| 293 | + PeriodStart: pgTime(snap.PeriodStart), | |
| 294 | + PeriodEnd: pgTime(snap.PeriodEnd), | |
| 295 | + DueAt: pgTime(snap.DueAt), | |
| 296 | + PaidAt: pgTime(snap.PaidAt), | |
| 297 | + VoidedAt: pgTime(snap.VoidedAt), | |
| 298 | + }) | |
| 299 | + if err != nil { | |
| 300 | + return billingdb.BillingInvoice{}, err | |
| 301 | + } | |
| 302 | + return row, nil | |
| 303 | +} | |
| 304 | + | |
| 181 | 305 | func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billingdb.BillingSeatSnapshot, error) { |
| 182 | 306 | if err := validateDeps(deps); err != nil { |
| 183 | 307 | return billingdb.BillingSeatSnapshot{}, err |
@@ -205,6 +329,20 @@ func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billin | ||
| 205 | 329 | return billingdb.BillingSeatSnapshot(row), nil |
| 206 | 330 | } |
| 207 | 331 | |
| 332 | +func CountBillableOrgMembers(ctx context.Context, deps Deps, orgID int64) (int, error) { | |
| 333 | + if err := validateDeps(deps); err != nil { | |
| 334 | + return 0, err | |
| 335 | + } | |
| 336 | + if orgID == 0 { | |
| 337 | + return 0, ErrOrgIDRequired | |
| 338 | + } | |
| 339 | + n, err := billingdb.New().CountBillableOrgMembers(ctx, deps.Pool, orgID) | |
| 340 | + if err != nil { | |
| 341 | + return 0, err | |
| 342 | + } | |
| 343 | + return int(n), nil | |
| 344 | +} | |
| 345 | + | |
| 208 | 346 | func MarkPastDue(ctx context.Context, deps Deps, orgID int64, graceUntil time.Time, lastWebhookEventID string) (State, error) { |
| 209 | 347 | if err := validateDeps(deps); err != nil { |
| 210 | 348 | return State{}, err |
@@ -282,6 +420,19 @@ func validStatus(status SubscriptionStatus) bool { | ||
| 282 | 420 | } |
| 283 | 421 | } |
| 284 | 422 | |
| 423 | +func validInvoiceStatus(status InvoiceStatus) bool { | |
| 424 | + switch status { | |
| 425 | + case InvoiceStatusDraft, | |
| 426 | + InvoiceStatusOpen, | |
| 427 | + InvoiceStatusPaid, | |
| 428 | + InvoiceStatusVoid, | |
| 429 | + InvoiceStatusUncollectible: | |
| 430 | + return true | |
| 431 | + default: | |
| 432 | + return false | |
| 433 | + } | |
| 434 | +} | |
| 435 | + | |
| 285 | 436 | func pgText(s string) pgtype.Text { |
| 286 | 437 | s = strings.TrimSpace(s) |
| 287 | 438 | return pgtype.Text{String: s, Valid: s != ""} |
internal/billing/billing_test.gomodified@@ -162,6 +162,13 @@ func TestRecordWebhookEventIsIdempotent(t *testing.T) { | ||
| 162 | 162 | if created { |
| 163 | 163 | t.Fatalf("duplicate receipt should not be created") |
| 164 | 164 | } |
| 165 | + | |
| 166 | + if _, err := billing.MarkWebhookEventProcessed(ctx, deps, event.ProviderEventID); err != nil { | |
| 167 | + t.Fatalf("MarkWebhookEventProcessed: %v", err) | |
| 168 | + } | |
| 169 | + if _, err := billing.MarkWebhookEventFailed(ctx, deps, event.ProviderEventID, "late duplicate"); err != nil { | |
| 170 | + t.Fatalf("MarkWebhookEventFailed: %v", err) | |
| 171 | + } | |
| 165 | 172 | } |
| 166 | 173 | |
| 167 | 174 | func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) { |
@@ -187,6 +194,75 @@ func TestSyncSeatSnapshotUpdatesBillingState(t *testing.T) { | ||
| 187 | 194 | if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid { |
| 188 | 195 | t.Fatalf("state did not record seat snapshot: %+v", state) |
| 189 | 196 | } |
| 197 | + | |
| 198 | + count, err := billing.CountBillableOrgMembers(ctx, deps, org.ID) | |
| 199 | + if err != nil { | |
| 200 | + t.Fatalf("CountBillableOrgMembers: %v", err) | |
| 201 | + } | |
| 202 | + if count != 1 { | |
| 203 | + t.Fatalf("billable members: got %d, want 1", count) | |
| 204 | + } | |
| 205 | +} | |
| 206 | + | |
| 207 | +func TestStripeLookupsAndInvoiceSnapshot(t *testing.T) { | |
| 208 | + _, deps, org := setup(t) | |
| 209 | + ctx := context.Background() | |
| 210 | + | |
| 211 | + start := time.Now().UTC().Truncate(time.Second) | |
| 212 | + if _, err := billing.SetStripeCustomer(ctx, deps, org.ID, "cus_lookup"); err != nil { | |
| 213 | + t.Fatalf("SetStripeCustomer: %v", err) | |
| 214 | + } | |
| 215 | + if _, err := billing.ApplySubscriptionSnapshot(ctx, deps, billing.SubscriptionSnapshot{ | |
| 216 | + OrgID: org.ID, | |
| 217 | + Plan: billing.PlanTeam, | |
| 218 | + Status: billing.SubscriptionStatusActive, | |
| 219 | + StripeSubscriptionID: "sub_lookup", | |
| 220 | + StripeSubscriptionItemID: "si_lookup", | |
| 221 | + CurrentPeriodStart: start, | |
| 222 | + CurrentPeriodEnd: start.Add(30 * 24 * time.Hour), | |
| 223 | + LastWebhookEventID: "evt_lookup", | |
| 224 | + }); err != nil { | |
| 225 | + t.Fatalf("ApplySubscriptionSnapshot: %v", err) | |
| 226 | + } | |
| 227 | + | |
| 228 | + byCustomer, err := billing.GetOrgBillingStateByStripeCustomer(ctx, deps, "cus_lookup") | |
| 229 | + if err != nil { | |
| 230 | + t.Fatalf("GetOrgBillingStateByStripeCustomer: %v", err) | |
| 231 | + } | |
| 232 | + if byCustomer.OrgID != org.ID { | |
| 233 | + t.Fatalf("customer lookup org_id: got %d, want %d", byCustomer.OrgID, org.ID) | |
| 234 | + } | |
| 235 | + bySubscription, err := billing.GetOrgBillingStateByStripeSubscription(ctx, deps, "sub_lookup") | |
| 236 | + if err != nil { | |
| 237 | + t.Fatalf("GetOrgBillingStateByStripeSubscription: %v", err) | |
| 238 | + } | |
| 239 | + if bySubscription.OrgID != org.ID { | |
| 240 | + t.Fatalf("subscription lookup org_id: got %d, want %d", bySubscription.OrgID, org.ID) | |
| 241 | + } | |
| 242 | + | |
| 243 | + invoice, err := billing.UpsertInvoice(ctx, deps, billing.InvoiceSnapshot{ | |
| 244 | + OrgID: org.ID, | |
| 245 | + StripeInvoiceID: "in_lookup", | |
| 246 | + StripeCustomerID: "cus_lookup", | |
| 247 | + StripeSubscriptionID: "sub_lookup", | |
| 248 | + Status: billing.InvoiceStatusPaid, | |
| 249 | + Number: "SHI-0001", | |
| 250 | + Currency: "USD", | |
| 251 | + AmountDueCents: 1200, | |
| 252 | + AmountPaidCents: 1200, | |
| 253 | + AmountRemainingCents: 0, | |
| 254 | + HostedInvoiceURL: "https://invoice.stripe.test/i", | |
| 255 | + InvoicePDFURL: "https://invoice.stripe.test/i.pdf", | |
| 256 | + PeriodStart: start, | |
| 257 | + PeriodEnd: start.Add(30 * 24 * time.Hour), | |
| 258 | + PaidAt: start.Add(time.Minute), | |
| 259 | + }) | |
| 260 | + if err != nil { | |
| 261 | + t.Fatalf("UpsertInvoice: %v", err) | |
| 262 | + } | |
| 263 | + if invoice.StripeInvoiceID != "in_lookup" || invoice.Status != billing.InvoiceStatusPaid || invoice.Currency != "usd" { | |
| 264 | + t.Fatalf("unexpected invoice: %+v", invoice) | |
| 265 | + } | |
| 190 | 266 | } |
| 191 | 267 | |
| 192 | 268 | func assertState(t *testing.T, state billing.State, plan billing.Plan, status billing.SubscriptionStatus) { |
internal/billing/queries/billing.sqlmodified@@ -5,6 +5,16 @@ | ||
| 5 | 5 | -- name: GetOrgBillingState :one |
| 6 | 6 | SELECT * FROM org_billing_states WHERE org_id = $1; |
| 7 | 7 | |
| 8 | +-- name: GetOrgBillingStateByStripeCustomer :one | |
| 9 | +SELECT * FROM org_billing_states | |
| 10 | +WHERE provider = 'stripe' | |
| 11 | + AND stripe_customer_id = $1; | |
| 12 | + | |
| 13 | +-- name: GetOrgBillingStateByStripeSubscription :one | |
| 14 | +SELECT * FROM org_billing_states | |
| 15 | +WHERE provider = 'stripe' | |
| 16 | + AND stripe_subscription_id = $1; | |
| 17 | + | |
| 8 | 18 | -- name: SetStripeCustomer :one |
| 9 | 19 | INSERT INTO org_billing_states (org_id, provider, stripe_customer_id) |
| 10 | 20 | VALUES ($1, 'stripe', $2) |
@@ -184,6 +194,11 @@ WHERE org_id = $1 | ||
| 184 | 194 | ORDER BY captured_at DESC, id DESC |
| 185 | 195 | LIMIT $2; |
| 186 | 196 | |
| 197 | +-- name: CountBillableOrgMembers :one | |
| 198 | +SELECT count(*)::integer | |
| 199 | +FROM org_members | |
| 200 | +WHERE org_id = $1; | |
| 201 | + | |
| 187 | 202 | -- ─── billing_invoices ────────────────────────────────────────────── |
| 188 | 203 | |
| 189 | 204 | -- name: UpsertInvoice :one |
internal/billing/sqlc/billing.sql.gomodified@@ -242,6 +242,19 @@ func (q *Queries) ClearBillingLock(ctx context.Context, db DBTX, orgID int64) (C | ||
| 242 | 242 | return i, err |
| 243 | 243 | } |
| 244 | 244 | |
| 245 | +const countBillableOrgMembers = `-- name: CountBillableOrgMembers :one | |
| 246 | +SELECT count(*)::integer | |
| 247 | +FROM org_members | |
| 248 | +WHERE org_id = $1 | |
| 249 | +` | |
| 250 | + | |
| 251 | +func (q *Queries) CountBillableOrgMembers(ctx context.Context, db DBTX, orgID int64) (int32, error) { | |
| 252 | + row := db.QueryRow(ctx, countBillableOrgMembers, orgID) | |
| 253 | + var column_1 int32 | |
| 254 | + err := row.Scan(&column_1) | |
| 255 | + return column_1, err | |
| 256 | +} | |
| 257 | + | |
| 245 | 258 | const createSeatSnapshot = `-- name: CreateSeatSnapshot :one |
| 246 | 259 | |
| 247 | 260 | WITH snapshot AS ( |
@@ -404,6 +417,76 @@ func (q *Queries) GetOrgBillingState(ctx context.Context, db DBTX, orgID int64) | ||
| 404 | 417 | return i, err |
| 405 | 418 | } |
| 406 | 419 | |
| 420 | +const getOrgBillingStateByStripeCustomer = `-- name: GetOrgBillingStateByStripeCustomer :one | |
| 421 | +SELECT org_id, provider, stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id, plan, subscription_status, billable_seats, seat_snapshot_at, current_period_start, current_period_end, cancel_at_period_end, trial_end, past_due_at, canceled_at, locked_at, lock_reason, grace_until, last_webhook_event_id, created_at, updated_at FROM org_billing_states | |
| 422 | +WHERE provider = 'stripe' | |
| 423 | + AND stripe_customer_id = $1 | |
| 424 | +` | |
| 425 | + | |
| 426 | +func (q *Queries) GetOrgBillingStateByStripeCustomer(ctx context.Context, db DBTX, stripeCustomerID pgtype.Text) (OrgBillingState, error) { | |
| 427 | + row := db.QueryRow(ctx, getOrgBillingStateByStripeCustomer, stripeCustomerID) | |
| 428 | + var i OrgBillingState | |
| 429 | + err := row.Scan( | |
| 430 | + &i.OrgID, | |
| 431 | + &i.Provider, | |
| 432 | + &i.StripeCustomerID, | |
| 433 | + &i.StripeSubscriptionID, | |
| 434 | + &i.StripeSubscriptionItemID, | |
| 435 | + &i.Plan, | |
| 436 | + &i.SubscriptionStatus, | |
| 437 | + &i.BillableSeats, | |
| 438 | + &i.SeatSnapshotAt, | |
| 439 | + &i.CurrentPeriodStart, | |
| 440 | + &i.CurrentPeriodEnd, | |
| 441 | + &i.CancelAtPeriodEnd, | |
| 442 | + &i.TrialEnd, | |
| 443 | + &i.PastDueAt, | |
| 444 | + &i.CanceledAt, | |
| 445 | + &i.LockedAt, | |
| 446 | + &i.LockReason, | |
| 447 | + &i.GraceUntil, | |
| 448 | + &i.LastWebhookEventID, | |
| 449 | + &i.CreatedAt, | |
| 450 | + &i.UpdatedAt, | |
| 451 | + ) | |
| 452 | + return i, err | |
| 453 | +} | |
| 454 | + | |
| 455 | +const getOrgBillingStateByStripeSubscription = `-- name: GetOrgBillingStateByStripeSubscription :one | |
| 456 | +SELECT org_id, provider, stripe_customer_id, stripe_subscription_id, stripe_subscription_item_id, plan, subscription_status, billable_seats, seat_snapshot_at, current_period_start, current_period_end, cancel_at_period_end, trial_end, past_due_at, canceled_at, locked_at, lock_reason, grace_until, last_webhook_event_id, created_at, updated_at FROM org_billing_states | |
| 457 | +WHERE provider = 'stripe' | |
| 458 | + AND stripe_subscription_id = $1 | |
| 459 | +` | |
| 460 | + | |
| 461 | +func (q *Queries) GetOrgBillingStateByStripeSubscription(ctx context.Context, db DBTX, stripeSubscriptionID pgtype.Text) (OrgBillingState, error) { | |
| 462 | + row := db.QueryRow(ctx, getOrgBillingStateByStripeSubscription, stripeSubscriptionID) | |
| 463 | + var i OrgBillingState | |
| 464 | + err := row.Scan( | |
| 465 | + &i.OrgID, | |
| 466 | + &i.Provider, | |
| 467 | + &i.StripeCustomerID, | |
| 468 | + &i.StripeSubscriptionID, | |
| 469 | + &i.StripeSubscriptionItemID, | |
| 470 | + &i.Plan, | |
| 471 | + &i.SubscriptionStatus, | |
| 472 | + &i.BillableSeats, | |
| 473 | + &i.SeatSnapshotAt, | |
| 474 | + &i.CurrentPeriodStart, | |
| 475 | + &i.CurrentPeriodEnd, | |
| 476 | + &i.CancelAtPeriodEnd, | |
| 477 | + &i.TrialEnd, | |
| 478 | + &i.PastDueAt, | |
| 479 | + &i.CanceledAt, | |
| 480 | + &i.LockedAt, | |
| 481 | + &i.LockReason, | |
| 482 | + &i.GraceUntil, | |
| 483 | + &i.LastWebhookEventID, | |
| 484 | + &i.CreatedAt, | |
| 485 | + &i.UpdatedAt, | |
| 486 | + ) | |
| 487 | + return i, err | |
| 488 | +} | |
| 489 | + | |
| 407 | 490 | const listInvoicesForOrg = `-- name: ListInvoicesForOrg :many |
| 408 | 491 | SELECT id, org_id, provider, stripe_invoice_id, stripe_customer_id, stripe_subscription_id, status, number, currency, amount_due_cents, amount_paid_cents, amount_remaining_cents, hosted_invoice_url, invoice_pdf_url, period_start, period_end, due_at, paid_at, voided_at, created_at, updated_at FROM billing_invoices |
| 409 | 492 | WHERE org_id = $1 |
internal/billing/sqlc/querier.gomodified@@ -6,11 +6,14 @@ package billingdb | ||
| 6 | 6 | |
| 7 | 7 | import ( |
| 8 | 8 | "context" |
| 9 | + | |
| 10 | + "github.com/jackc/pgx/v5/pgtype" | |
| 9 | 11 | ) |
| 10 | 12 | |
| 11 | 13 | type Querier interface { |
| 12 | 14 | ApplySubscriptionSnapshot(ctx context.Context, db DBTX, arg ApplySubscriptionSnapshotParams) (ApplySubscriptionSnapshotRow, error) |
| 13 | 15 | ClearBillingLock(ctx context.Context, db DBTX, orgID int64) (ClearBillingLockRow, error) |
| 16 | + CountBillableOrgMembers(ctx context.Context, db DBTX, orgID int64) (int32, error) | |
| 14 | 17 | // ─── billing_seat_snapshots ──────────────────────────────────────── |
| 15 | 18 | CreateSeatSnapshot(ctx context.Context, db DBTX, arg CreateSeatSnapshotParams) (CreateSeatSnapshotRow, error) |
| 16 | 19 | // ─── billing_webhook_events ──────────────────────────────────────── |
@@ -18,6 +21,8 @@ type Querier interface { | ||
| 18 | 21 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 19 | 22 | // ─── org_billing_states ──────────────────────────────────────────── |
| 20 | 23 | GetOrgBillingState(ctx context.Context, db DBTX, orgID int64) (OrgBillingState, error) |
| 24 | + GetOrgBillingStateByStripeCustomer(ctx context.Context, db DBTX, stripeCustomerID pgtype.Text) (OrgBillingState, error) | |
| 25 | + GetOrgBillingStateByStripeSubscription(ctx context.Context, db DBTX, stripeSubscriptionID pgtype.Text) (OrgBillingState, error) | |
| 21 | 26 | ListInvoicesForOrg(ctx context.Context, db DBTX, arg ListInvoicesForOrgParams) ([]BillingInvoice, error) |
| 22 | 27 | ListSeatSnapshotsForOrg(ctx context.Context, db DBTX, arg ListSeatSnapshotsForOrgParams) ([]BillingSeatSnapshot, error) |
| 23 | 28 | MarkCanceled(ctx context.Context, db DBTX, arg MarkCanceledParams) (MarkCanceledRow, error) |
internal/billing/stripebilling/client.goadded@@ -0,0 +1,244 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +// Package stripebilling contains the Stripe-specific edge of the billing | |
| 4 | +// system. Local subscription state stays in internal/billing; this package | |
| 5 | +// owns hosted Checkout, Billing Portal, seat quantity updates, and webhook | |
| 6 | +// signature verification. | |
| 7 | +package stripebilling | |
| 8 | + | |
| 9 | +import ( | |
| 10 | + "context" | |
| 11 | + "errors" | |
| 12 | + "fmt" | |
| 13 | + "strconv" | |
| 14 | + "strings" | |
| 15 | + | |
| 16 | + stripeapi "github.com/stripe/stripe-go/v85" | |
| 17 | + "github.com/stripe/stripe-go/v85/webhook" | |
| 18 | +) | |
| 19 | + | |
| 20 | +const ( | |
| 21 | + MetadataOrgID = "shithub_org_id" | |
| 22 | + MetadataOrgSlug = "shithub_org_slug" | |
| 23 | +) | |
| 24 | + | |
| 25 | +var ( | |
| 26 | + ErrSecretKeyRequired = errors.New("stripe billing: secret key is required") | |
| 27 | + ErrWebhookSecretRequired = errors.New("stripe billing: webhook secret is required") | |
| 28 | + ErrTeamPriceRequired = errors.New("stripe billing: team price id is required") | |
| 29 | + ErrCustomerIDRequired = errors.New("stripe billing: customer id is required") | |
| 30 | + ErrSubscriptionItemID = errors.New("stripe billing: subscription item id is required") | |
| 31 | + ErrURLRequired = errors.New("stripe billing: redirect url is required") | |
| 32 | +) | |
| 33 | + | |
| 34 | +type Config struct { | |
| 35 | + SecretKey string | |
| 36 | + WebhookSecret string | |
| 37 | + TeamPriceID string | |
| 38 | + AutomaticTax bool | |
| 39 | +} | |
| 40 | + | |
| 41 | +type Remote interface { | |
| 42 | + CreateCustomer(context.Context, CustomerInput) (Customer, error) | |
| 43 | + CreateCheckoutSession(context.Context, CheckoutInput) (CheckoutSession, error) | |
| 44 | + CreatePortalSession(context.Context, PortalInput) (PortalSession, error) | |
| 45 | + UpdateSubscriptionItemQuantity(context.Context, SeatQuantityInput) error | |
| 46 | + VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error) | |
| 47 | +} | |
| 48 | + | |
| 49 | +type Client struct { | |
| 50 | + stripe *stripeapi.Client | |
| 51 | + webhookSecret string | |
| 52 | + teamPriceID string | |
| 53 | + automaticTax bool | |
| 54 | +} | |
| 55 | + | |
| 56 | +type CustomerInput struct { | |
| 57 | + OrgID int64 | |
| 58 | + OrgSlug string | |
| 59 | + OrgName string | |
| 60 | + Email string | |
| 61 | +} | |
| 62 | + | |
| 63 | +type Customer struct { | |
| 64 | + ID string | |
| 65 | +} | |
| 66 | + | |
| 67 | +type CheckoutInput struct { | |
| 68 | + OrgID int64 | |
| 69 | + OrgSlug string | |
| 70 | + CustomerID string | |
| 71 | + SeatCount int64 | |
| 72 | + SuccessURL string | |
| 73 | + CancelURL string | |
| 74 | +} | |
| 75 | + | |
| 76 | +type CheckoutSession struct { | |
| 77 | + ID string | |
| 78 | + URL string | |
| 79 | +} | |
| 80 | + | |
| 81 | +type PortalInput struct { | |
| 82 | + CustomerID string | |
| 83 | + ReturnURL string | |
| 84 | +} | |
| 85 | + | |
| 86 | +type PortalSession struct { | |
| 87 | + ID string | |
| 88 | + URL string | |
| 89 | +} | |
| 90 | + | |
| 91 | +type SeatQuantityInput struct { | |
| 92 | + OrgID int64 | |
| 93 | + SubscriptionItemID string | |
| 94 | + Quantity int64 | |
| 95 | +} | |
| 96 | + | |
| 97 | +func New(cfg Config) (*Client, error) { | |
| 98 | + cfg.SecretKey = strings.TrimSpace(cfg.SecretKey) | |
| 99 | + cfg.WebhookSecret = strings.TrimSpace(cfg.WebhookSecret) | |
| 100 | + cfg.TeamPriceID = strings.TrimSpace(cfg.TeamPriceID) | |
| 101 | + if cfg.SecretKey == "" { | |
| 102 | + return nil, ErrSecretKeyRequired | |
| 103 | + } | |
| 104 | + if cfg.WebhookSecret == "" { | |
| 105 | + return nil, ErrWebhookSecretRequired | |
| 106 | + } | |
| 107 | + if cfg.TeamPriceID == "" { | |
| 108 | + return nil, ErrTeamPriceRequired | |
| 109 | + } | |
| 110 | + return &Client{ | |
| 111 | + stripe: stripeapi.NewClient(cfg.SecretKey), | |
| 112 | + webhookSecret: cfg.WebhookSecret, | |
| 113 | + teamPriceID: cfg.TeamPriceID, | |
| 114 | + automaticTax: cfg.AutomaticTax, | |
| 115 | + }, nil | |
| 116 | +} | |
| 117 | + | |
| 118 | +func (c *Client) CreateCustomer(ctx context.Context, in CustomerInput) (Customer, error) { | |
| 119 | + name := strings.TrimSpace(in.OrgName) | |
| 120 | + if name == "" { | |
| 121 | + name = strings.TrimSpace(in.OrgSlug) | |
| 122 | + } | |
| 123 | + params := &stripeapi.CustomerCreateParams{ | |
| 124 | + Name: stripeapi.String(name), | |
| 125 | + Description: stripeapi.String(fmt.Sprintf("shithub organization %s", strings.TrimSpace(in.OrgSlug))), | |
| 126 | + Metadata: orgMetadata(in.OrgID, in.OrgSlug), | |
| 127 | + } | |
| 128 | + if email := strings.TrimSpace(in.Email); email != "" { | |
| 129 | + params.Email = stripeapi.String(email) | |
| 130 | + } | |
| 131 | + params.SetIdempotencyKey(idempotencyKey("customer", in.OrgID, "v1")) | |
| 132 | + customer, err := c.stripe.V1Customers.Create(ctx, params) | |
| 133 | + if err != nil { | |
| 134 | + return Customer{}, err | |
| 135 | + } | |
| 136 | + return Customer{ID: customer.ID}, nil | |
| 137 | +} | |
| 138 | + | |
| 139 | +func (c *Client) CreateCheckoutSession(ctx context.Context, in CheckoutInput) (CheckoutSession, error) { | |
| 140 | + in.CustomerID = strings.TrimSpace(in.CustomerID) | |
| 141 | + if in.CustomerID == "" { | |
| 142 | + return CheckoutSession{}, ErrCustomerIDRequired | |
| 143 | + } | |
| 144 | + in.SuccessURL = strings.TrimSpace(in.SuccessURL) | |
| 145 | + if in.SuccessURL == "" { | |
| 146 | + return CheckoutSession{}, fmt.Errorf("%w: success_url", ErrURLRequired) | |
| 147 | + } | |
| 148 | + in.CancelURL = strings.TrimSpace(in.CancelURL) | |
| 149 | + if in.CancelURL == "" { | |
| 150 | + return CheckoutSession{}, fmt.Errorf("%w: cancel_url", ErrURLRequired) | |
| 151 | + } | |
| 152 | + if in.SeatCount < 1 { | |
| 153 | + in.SeatCount = 1 | |
| 154 | + } | |
| 155 | + metadata := orgMetadata(in.OrgID, in.OrgSlug) | |
| 156 | + mode := string(stripeapi.CheckoutSessionModeSubscription) | |
| 157 | + paymentMethodCollection := string(stripeapi.CheckoutSessionPaymentMethodCollectionAlways) | |
| 158 | + billingAddressCollection := string(stripeapi.CheckoutSessionBillingAddressCollectionAuto) | |
| 159 | + params := &stripeapi.CheckoutSessionCreateParams{ | |
| 160 | + Mode: stripeapi.String(mode), | |
| 161 | + Customer: stripeapi.String(in.CustomerID), | |
| 162 | + ClientReferenceID: stripeapi.String(strconv.FormatInt(in.OrgID, 10)), | |
| 163 | + SuccessURL: stripeapi.String(in.SuccessURL), | |
| 164 | + CancelURL: stripeapi.String(in.CancelURL), | |
| 165 | + PaymentMethodCollection: stripeapi.String(paymentMethodCollection), | |
| 166 | + BillingAddressCollection: stripeapi.String(billingAddressCollection), | |
| 167 | + LineItems: []*stripeapi.CheckoutSessionCreateLineItemParams{{ | |
| 168 | + Price: stripeapi.String(c.teamPriceID), | |
| 169 | + Quantity: stripeapi.Int64(in.SeatCount), | |
| 170 | + }}, | |
| 171 | + Metadata: metadata, | |
| 172 | + SubscriptionData: &stripeapi.CheckoutSessionCreateSubscriptionDataParams{ | |
| 173 | + Metadata: metadata, | |
| 174 | + }, | |
| 175 | + } | |
| 176 | + if c.automaticTax { | |
| 177 | + params.AutomaticTax = &stripeapi.CheckoutSessionCreateAutomaticTaxParams{ | |
| 178 | + Enabled: stripeapi.Bool(true), | |
| 179 | + } | |
| 180 | + } | |
| 181 | + params.SetIdempotencyKey(idempotencyKey("checkout", in.OrgID, "team", strconv.FormatInt(in.SeatCount, 10))) | |
| 182 | + session, err := c.stripe.V1CheckoutSessions.Create(ctx, params) | |
| 183 | + if err != nil { | |
| 184 | + return CheckoutSession{}, err | |
| 185 | + } | |
| 186 | + return CheckoutSession{ID: session.ID, URL: session.URL}, nil | |
| 187 | +} | |
| 188 | + | |
| 189 | +func (c *Client) CreatePortalSession(ctx context.Context, in PortalInput) (PortalSession, error) { | |
| 190 | + in.CustomerID = strings.TrimSpace(in.CustomerID) | |
| 191 | + if in.CustomerID == "" { | |
| 192 | + return PortalSession{}, ErrCustomerIDRequired | |
| 193 | + } | |
| 194 | + in.ReturnURL = strings.TrimSpace(in.ReturnURL) | |
| 195 | + if in.ReturnURL == "" { | |
| 196 | + return PortalSession{}, fmt.Errorf("%w: portal_return_url", ErrURLRequired) | |
| 197 | + } | |
| 198 | + params := &stripeapi.BillingPortalSessionCreateParams{ | |
| 199 | + Customer: stripeapi.String(in.CustomerID), | |
| 200 | + ReturnURL: stripeapi.String(in.ReturnURL), | |
| 201 | + } | |
| 202 | + session, err := c.stripe.V1BillingPortalSessions.Create(ctx, params) | |
| 203 | + if err != nil { | |
| 204 | + return PortalSession{}, err | |
| 205 | + } | |
| 206 | + return PortalSession{ID: session.ID, URL: session.URL}, nil | |
| 207 | +} | |
| 208 | + | |
| 209 | +func (c *Client) UpdateSubscriptionItemQuantity(ctx context.Context, in SeatQuantityInput) error { | |
| 210 | + in.SubscriptionItemID = strings.TrimSpace(in.SubscriptionItemID) | |
| 211 | + if in.SubscriptionItemID == "" { | |
| 212 | + return ErrSubscriptionItemID | |
| 213 | + } | |
| 214 | + if in.Quantity < 1 { | |
| 215 | + in.Quantity = 1 | |
| 216 | + } | |
| 217 | + params := &stripeapi.SubscriptionItemUpdateParams{ | |
| 218 | + Quantity: stripeapi.Int64(in.Quantity), | |
| 219 | + } | |
| 220 | + params.SetIdempotencyKey(idempotencyKey("seat-sync", in.OrgID, in.SubscriptionItemID, strconv.FormatInt(in.Quantity, 10))) | |
| 221 | + _, err := c.stripe.V1SubscriptionItems.Update(ctx, in.SubscriptionItemID, params) | |
| 222 | + return err | |
| 223 | +} | |
| 224 | + | |
| 225 | +func (c *Client) VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error) { | |
| 226 | + return webhook.ConstructEvent(payload, signatureHeader, c.webhookSecret) | |
| 227 | +} | |
| 228 | + | |
| 229 | +func orgMetadata(orgID int64, orgSlug string) map[string]string { | |
| 230 | + return map[string]string{ | |
| 231 | + MetadataOrgID: strconv.FormatInt(orgID, 10), | |
| 232 | + MetadataOrgSlug: strings.TrimSpace(orgSlug), | |
| 233 | + } | |
| 234 | +} | |
| 235 | + | |
| 236 | +func idempotencyKey(parts ...any) string { | |
| 237 | + var b strings.Builder | |
| 238 | + b.WriteString("shithub") | |
| 239 | + for _, part := range parts { | |
| 240 | + b.WriteByte(':') | |
| 241 | + b.WriteString(strings.NewReplacer(":", "_", " ", "_", "/", "_").Replace(fmt.Sprint(part))) | |
| 242 | + } | |
| 243 | + return b.String() | |
| 244 | +} | |
internal/billing/stripebilling/client_test.goadded@@ -0,0 +1,52 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package stripebilling | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "errors" | |
| 7 | + "fmt" | |
| 8 | + "testing" | |
| 9 | + | |
| 10 | + stripeapi "github.com/stripe/stripe-go/v85" | |
| 11 | + "github.com/stripe/stripe-go/v85/webhook" | |
| 12 | +) | |
| 13 | + | |
| 14 | +func TestNewValidatesRequiredConfig(t *testing.T) { | |
| 15 | + t.Parallel() | |
| 16 | + if _, err := New(Config{}); !errors.Is(err, ErrSecretKeyRequired) { | |
| 17 | + t.Fatalf("New without secret key: got %v", err) | |
| 18 | + } | |
| 19 | + if _, err := New(Config{SecretKey: "sk_test_123"}); !errors.Is(err, ErrWebhookSecretRequired) { | |
| 20 | + t.Fatalf("New without webhook secret: got %v", err) | |
| 21 | + } | |
| 22 | + if _, err := New(Config{SecretKey: "sk_test_123", WebhookSecret: "whsec_123"}); !errors.Is(err, ErrTeamPriceRequired) { | |
| 23 | + t.Fatalf("New without price id: got %v", err) | |
| 24 | + } | |
| 25 | +} | |
| 26 | + | |
| 27 | +func TestVerifyWebhookUsesSigningSecret(t *testing.T) { | |
| 28 | + t.Parallel() | |
| 29 | + client, err := New(Config{ | |
| 30 | + SecretKey: "sk_test_123", | |
| 31 | + WebhookSecret: "whsec_test", | |
| 32 | + TeamPriceID: "price_123", | |
| 33 | + }) | |
| 34 | + if err != nil { | |
| 35 | + t.Fatalf("New: %v", err) | |
| 36 | + } | |
| 37 | + payload := []byte(fmt.Sprintf(`{"id":"evt_test","object":"event","api_version":%q,"type":"customer.subscription.updated","data":{"object":{"id":"sub_test","object":"subscription"}}}`, stripeapi.APIVersion)) | |
| 38 | + signed := webhook.GenerateTestSignedPayload(&webhook.UnsignedPayload{ | |
| 39 | + Payload: payload, | |
| 40 | + Secret: "whsec_test", | |
| 41 | + }) | |
| 42 | + event, err := client.VerifyWebhook(payload, signed.Header) | |
| 43 | + if err != nil { | |
| 44 | + t.Fatalf("VerifyWebhook: %v", err) | |
| 45 | + } | |
| 46 | + if event.ID != "evt_test" || event.Type != "customer.subscription.updated" { | |
| 47 | + t.Fatalf("unexpected event: id=%s type=%s", event.ID, event.Type) | |
| 48 | + } | |
| 49 | + if _, err := client.VerifyWebhook(payload, "t=1,v1=bad"); err == nil { | |
| 50 | + t.Fatalf("VerifyWebhook accepted bad signature") | |
| 51 | + } | |
| 52 | +} | |