Go · 17010 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package billing owns local paid-organization state. It stores Stripe
4 // identifiers and derived subscription state, but it does not call
5 // Stripe directly; Stripe API details stay in the SP03 adapter layer.
6 package billing
7
8 import (
9 "context"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "strings"
14 "time"
15
16 "github.com/jackc/pgx/v5"
17 "github.com/jackc/pgx/v5/pgtype"
18 "github.com/jackc/pgx/v5/pgxpool"
19
20 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
21 )
22
23 type Deps struct {
24 Pool *pgxpool.Pool
25 }
26
27 type (
28 Plan = billingdb.OrgPlan
29 SubscriptionStatus = billingdb.BillingSubscriptionStatus
30 InvoiceStatus = billingdb.BillingInvoiceStatus
31 State = billingdb.OrgBillingState
32 )
33
34 const (
35 PlanFree = billingdb.OrgPlanFree
36 PlanTeam = billingdb.OrgPlanTeam
37 PlanEnterprise = billingdb.OrgPlanEnterprise
38
39 SubscriptionStatusNone = billingdb.BillingSubscriptionStatusNone
40 SubscriptionStatusIncomplete = billingdb.BillingSubscriptionStatusIncomplete
41 SubscriptionStatusTrialing = billingdb.BillingSubscriptionStatusTrialing
42 SubscriptionStatusActive = billingdb.BillingSubscriptionStatusActive
43 SubscriptionStatusPastDue = billingdb.BillingSubscriptionStatusPastDue
44 SubscriptionStatusCanceled = billingdb.BillingSubscriptionStatusCanceled
45 SubscriptionStatusUnpaid = billingdb.BillingSubscriptionStatusUnpaid
46 SubscriptionStatusPaused = billingdb.BillingSubscriptionStatusPaused
47
48 InvoiceStatusDraft = billingdb.BillingInvoiceStatusDraft
49 InvoiceStatusOpen = billingdb.BillingInvoiceStatusOpen
50 InvoiceStatusPaid = billingdb.BillingInvoiceStatusPaid
51 InvoiceStatusVoid = billingdb.BillingInvoiceStatusVoid
52 InvoiceStatusUncollectible = billingdb.BillingInvoiceStatusUncollectible
53 )
54
55 var (
56 ErrPoolRequired = errors.New("billing: pool is required")
57 ErrOrgIDRequired = errors.New("billing: org id is required")
58 ErrStripeCustomerID = errors.New("billing: stripe customer id is required")
59 ErrStripeSubscriptionID = errors.New("billing: stripe subscription id is required")
60 ErrStripeInvoiceID = errors.New("billing: stripe invoice id is required")
61 ErrInvalidPlan = errors.New("billing: invalid plan")
62 ErrInvalidStatus = errors.New("billing: invalid subscription status")
63 ErrInvalidInvoiceStatus = errors.New("billing: invalid invoice status")
64 ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative")
65 ErrWebhookEventID = errors.New("billing: webhook event id is required")
66 ErrWebhookEventType = errors.New("billing: webhook event type is required")
67 ErrWebhookPayload = errors.New("billing: webhook payload must be a JSON object")
68 )
69
70 // SubscriptionSnapshot is the local projection of a provider
71 // subscription event. Provider-specific conversion belongs in SP03.
72 type SubscriptionSnapshot struct {
73 OrgID int64
74 Plan Plan
75 Status SubscriptionStatus
76 StripeSubscriptionID string
77 StripeSubscriptionItemID string
78 CurrentPeriodStart time.Time
79 CurrentPeriodEnd time.Time
80 CancelAtPeriodEnd bool
81 TrialEnd time.Time
82 CanceledAt time.Time
83 LastWebhookEventID string
84 }
85
86 type SeatSnapshot struct {
87 OrgID int64
88 StripeSubscriptionID string
89 ActiveMembers int
90 BillableSeats int
91 Source string
92 }
93
94 type WebhookEvent struct {
95 ProviderEventID string
96 EventType string
97 APIVersion string
98 Payload []byte
99 }
100
101 type InvoiceSnapshot struct {
102 OrgID int64
103 StripeInvoiceID string
104 StripeCustomerID string
105 StripeSubscriptionID string
106 Status InvoiceStatus
107 Number string
108 Currency string
109 AmountDueCents int64
110 AmountPaidCents int64
111 AmountRemainingCents int64
112 HostedInvoiceURL string
113 InvoicePDFURL string
114 PeriodStart time.Time
115 PeriodEnd time.Time
116 DueAt time.Time
117 PaidAt time.Time
118 VoidedAt time.Time
119 }
120
121 func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, error) {
122 if err := validateDeps(deps); err != nil {
123 return State{}, err
124 }
125 if orgID == 0 {
126 return State{}, ErrOrgIDRequired
127 }
128 return billingdb.New().GetOrgBillingState(ctx, deps.Pool, orgID)
129 }
130
131 func GetOrgBillingStateByStripeCustomer(ctx context.Context, deps Deps, customerID string) (State, error) {
132 if err := validateDeps(deps); err != nil {
133 return State{}, err
134 }
135 customerID = strings.TrimSpace(customerID)
136 if customerID == "" {
137 return State{}, ErrStripeCustomerID
138 }
139 return billingdb.New().GetOrgBillingStateByStripeCustomer(ctx, deps.Pool, pgText(customerID))
140 }
141
142 func GetOrgBillingStateByStripeSubscription(ctx context.Context, deps Deps, subscriptionID string) (State, error) {
143 if err := validateDeps(deps); err != nil {
144 return State{}, err
145 }
146 subscriptionID = strings.TrimSpace(subscriptionID)
147 if subscriptionID == "" {
148 return State{}, ErrStripeSubscriptionID
149 }
150 return billingdb.New().GetOrgBillingStateByStripeSubscription(ctx, deps.Pool, pgText(subscriptionID))
151 }
152
153 func SetStripeCustomer(ctx context.Context, deps Deps, orgID int64, customerID string) (State, error) {
154 if err := validateDeps(deps); err != nil {
155 return State{}, err
156 }
157 if orgID == 0 {
158 return State{}, ErrOrgIDRequired
159 }
160 customerID = strings.TrimSpace(customerID)
161 if customerID == "" {
162 return State{}, ErrStripeCustomerID
163 }
164 return billingdb.New().SetStripeCustomer(ctx, deps.Pool, billingdb.SetStripeCustomerParams{
165 OrgID: orgID,
166 StripeCustomerID: pgText(customerID),
167 })
168 }
169
170 func ApplySubscriptionSnapshot(ctx context.Context, deps Deps, snap SubscriptionSnapshot) (State, error) {
171 if err := validateDeps(deps); err != nil {
172 return State{}, err
173 }
174 if snap.OrgID == 0 {
175 return State{}, ErrOrgIDRequired
176 }
177 if !validPlan(snap.Plan) {
178 return State{}, fmt.Errorf("%w: %q", ErrInvalidPlan, snap.Plan)
179 }
180 if !validStatus(snap.Status) {
181 return State{}, fmt.Errorf("%w: %q", ErrInvalidStatus, snap.Status)
182 }
183 row, err := billingdb.New().ApplySubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplySubscriptionSnapshotParams{
184 OrgID: snap.OrgID,
185 Plan: snap.Plan,
186 SubscriptionStatus: snap.Status,
187 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
188 StripeSubscriptionItemID: pgText(snap.StripeSubscriptionItemID),
189 CurrentPeriodStart: pgTime(snap.CurrentPeriodStart),
190 CurrentPeriodEnd: pgTime(snap.CurrentPeriodEnd),
191 CancelAtPeriodEnd: snap.CancelAtPeriodEnd,
192 TrialEnd: pgTime(snap.TrialEnd),
193 CanceledAt: pgTime(snap.CanceledAt),
194 LastWebhookEventID: strings.TrimSpace(snap.LastWebhookEventID),
195 })
196 if err != nil {
197 return State{}, err
198 }
199 return stateFromApply(row), nil
200 }
201
202 func RecordWebhookEvent(ctx context.Context, deps Deps, event WebhookEvent) (billingdb.BillingWebhookEvent, bool, error) {
203 if err := validateDeps(deps); err != nil {
204 return billingdb.BillingWebhookEvent{}, false, err
205 }
206 event.ProviderEventID = strings.TrimSpace(event.ProviderEventID)
207 event.EventType = strings.TrimSpace(event.EventType)
208 event.APIVersion = strings.TrimSpace(event.APIVersion)
209 if event.ProviderEventID == "" {
210 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventID
211 }
212 if event.EventType == "" {
213 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventType
214 }
215 if !jsonObject(event.Payload) {
216 return billingdb.BillingWebhookEvent{}, false, ErrWebhookPayload
217 }
218 row, err := billingdb.New().CreateWebhookEventReceipt(ctx, deps.Pool, billingdb.CreateWebhookEventReceiptParams{
219 ProviderEventID: event.ProviderEventID,
220 EventType: event.EventType,
221 ApiVersion: event.APIVersion,
222 Payload: event.Payload,
223 })
224 if err != nil {
225 if errors.Is(err, pgx.ErrNoRows) {
226 row, err = billingdb.New().GetWebhookEventReceipt(ctx, deps.Pool, event.ProviderEventID)
227 if err != nil {
228 return billingdb.BillingWebhookEvent{}, false, err
229 }
230 return row, false, nil
231 }
232 return billingdb.BillingWebhookEvent{}, false, err
233 }
234 return row, true, nil
235 }
236
237 func MarkWebhookEventProcessed(ctx context.Context, deps Deps, providerEventID string) (billingdb.BillingWebhookEvent, error) {
238 if err := validateDeps(deps); err != nil {
239 return billingdb.BillingWebhookEvent{}, err
240 }
241 providerEventID = strings.TrimSpace(providerEventID)
242 if providerEventID == "" {
243 return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
244 }
245 return billingdb.New().MarkWebhookEventProcessed(ctx, deps.Pool, providerEventID)
246 }
247
248 func MarkWebhookEventFailed(ctx context.Context, deps Deps, providerEventID, processError string) (billingdb.BillingWebhookEvent, error) {
249 if err := validateDeps(deps); err != nil {
250 return billingdb.BillingWebhookEvent{}, err
251 }
252 providerEventID = strings.TrimSpace(providerEventID)
253 if providerEventID == "" {
254 return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
255 }
256 processError = strings.TrimSpace(processError)
257 if len(processError) > 2000 {
258 processError = processError[:2000]
259 }
260 return billingdb.New().MarkWebhookEventFailed(ctx, deps.Pool, billingdb.MarkWebhookEventFailedParams{
261 ProviderEventID: providerEventID,
262 ProcessError: processError,
263 })
264 }
265
266 func GetWebhookEventReceipt(ctx context.Context, deps Deps, providerEventID string) (billingdb.BillingWebhookEvent, error) {
267 if err := validateDeps(deps); err != nil {
268 return billingdb.BillingWebhookEvent{}, err
269 }
270 providerEventID = strings.TrimSpace(providerEventID)
271 if providerEventID == "" {
272 return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
273 }
274 return billingdb.New().GetWebhookEventReceipt(ctx, deps.Pool, providerEventID)
275 }
276
277 func UpsertInvoice(ctx context.Context, deps Deps, snap InvoiceSnapshot) (billingdb.BillingInvoice, error) {
278 if err := validateDeps(deps); err != nil {
279 return billingdb.BillingInvoice{}, err
280 }
281 if snap.OrgID == 0 {
282 return billingdb.BillingInvoice{}, ErrOrgIDRequired
283 }
284 snap.StripeInvoiceID = strings.TrimSpace(snap.StripeInvoiceID)
285 if snap.StripeInvoiceID == "" {
286 return billingdb.BillingInvoice{}, ErrStripeInvoiceID
287 }
288 snap.StripeCustomerID = strings.TrimSpace(snap.StripeCustomerID)
289 if snap.StripeCustomerID == "" {
290 return billingdb.BillingInvoice{}, ErrStripeCustomerID
291 }
292 if !validInvoiceStatus(snap.Status) {
293 return billingdb.BillingInvoice{}, fmt.Errorf("%w: %q", ErrInvalidInvoiceStatus, snap.Status)
294 }
295 row, err := billingdb.New().UpsertInvoice(ctx, deps.Pool, billingdb.UpsertInvoiceParams{
296 OrgID: snap.OrgID,
297 StripeInvoiceID: snap.StripeInvoiceID,
298 StripeCustomerID: snap.StripeCustomerID,
299 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
300 Status: snap.Status,
301 Number: strings.TrimSpace(snap.Number),
302 Currency: strings.ToLower(strings.TrimSpace(snap.Currency)),
303 AmountDueCents: snap.AmountDueCents,
304 AmountPaidCents: snap.AmountPaidCents,
305 AmountRemainingCents: snap.AmountRemainingCents,
306 HostedInvoiceUrl: strings.TrimSpace(snap.HostedInvoiceURL),
307 InvoicePdfUrl: strings.TrimSpace(snap.InvoicePDFURL),
308 PeriodStart: pgTime(snap.PeriodStart),
309 PeriodEnd: pgTime(snap.PeriodEnd),
310 DueAt: pgTime(snap.DueAt),
311 PaidAt: pgTime(snap.PaidAt),
312 VoidedAt: pgTime(snap.VoidedAt),
313 })
314 if err != nil {
315 return billingdb.BillingInvoice{}, err
316 }
317 return row, nil
318 }
319
320 func ListInvoicesForOrg(ctx context.Context, deps Deps, orgID int64, limit int32) ([]billingdb.BillingInvoice, error) {
321 if err := validateDeps(deps); err != nil {
322 return nil, err
323 }
324 if orgID == 0 {
325 return nil, ErrOrgIDRequired
326 }
327 if limit <= 0 {
328 limit = 10
329 }
330 return billingdb.New().ListInvoicesForOrg(ctx, deps.Pool, billingdb.ListInvoicesForOrgParams{
331 OrgID: orgID,
332 Limit: limit,
333 })
334 }
335
336 func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billingdb.BillingSeatSnapshot, error) {
337 if err := validateDeps(deps); err != nil {
338 return billingdb.BillingSeatSnapshot{}, err
339 }
340 if snap.OrgID == 0 {
341 return billingdb.BillingSeatSnapshot{}, ErrOrgIDRequired
342 }
343 if snap.ActiveMembers < 0 || snap.BillableSeats < 0 {
344 return billingdb.BillingSeatSnapshot{}, ErrInvalidSeatCount
345 }
346 source := strings.TrimSpace(snap.Source)
347 if source == "" {
348 source = "local"
349 }
350 row, err := billingdb.New().CreateSeatSnapshot(ctx, deps.Pool, billingdb.CreateSeatSnapshotParams{
351 OrgID: snap.OrgID,
352 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
353 ActiveMembers: int32(snap.ActiveMembers),
354 BillableSeats: int32(snap.BillableSeats),
355 Source: source,
356 })
357 if err != nil {
358 return billingdb.BillingSeatSnapshot{}, err
359 }
360 return billingdb.BillingSeatSnapshot(row), nil
361 }
362
363 func CountBillableOrgMembers(ctx context.Context, deps Deps, orgID int64) (int, error) {
364 if err := validateDeps(deps); err != nil {
365 return 0, err
366 }
367 if orgID == 0 {
368 return 0, ErrOrgIDRequired
369 }
370 n, err := billingdb.New().CountBillableOrgMembers(ctx, deps.Pool, orgID)
371 if err != nil {
372 return 0, err
373 }
374 return int(n), nil
375 }
376
377 func CountPendingOrgInvitations(ctx context.Context, deps Deps, orgID int64) (int, error) {
378 if err := validateDeps(deps); err != nil {
379 return 0, err
380 }
381 if orgID == 0 {
382 return 0, ErrOrgIDRequired
383 }
384 n, err := billingdb.New().CountPendingOrgInvitations(ctx, deps.Pool, orgID)
385 if err != nil {
386 return 0, err
387 }
388 return int(n), nil
389 }
390
391 func MarkPastDue(ctx context.Context, deps Deps, orgID int64, graceUntil time.Time, lastWebhookEventID string) (State, error) {
392 if err := validateDeps(deps); err != nil {
393 return State{}, err
394 }
395 if orgID == 0 {
396 return State{}, ErrOrgIDRequired
397 }
398 return billingdb.New().MarkPastDue(ctx, deps.Pool, billingdb.MarkPastDueParams{
399 OrgID: orgID,
400 GraceUntil: pgTime(graceUntil),
401 LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
402 })
403 }
404
405 func MarkPaymentSucceeded(ctx context.Context, deps Deps, orgID int64, lastWebhookEventID string) (State, error) {
406 if err := validateDeps(deps); err != nil {
407 return State{}, err
408 }
409 if orgID == 0 {
410 return State{}, ErrOrgIDRequired
411 }
412 row, err := billingdb.New().MarkPaymentSucceeded(ctx, deps.Pool, billingdb.MarkPaymentSucceededParams{
413 OrgID: orgID,
414 LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
415 })
416 if err != nil {
417 return State{}, err
418 }
419 return stateFromPaymentSucceeded(row), nil
420 }
421
422 func MarkCanceled(ctx context.Context, deps Deps, orgID int64, lastWebhookEventID string) (State, error) {
423 if err := validateDeps(deps); err != nil {
424 return State{}, err
425 }
426 if orgID == 0 {
427 return State{}, ErrOrgIDRequired
428 }
429 row, err := billingdb.New().MarkCanceled(ctx, deps.Pool, billingdb.MarkCanceledParams{
430 OrgID: orgID,
431 LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
432 })
433 if err != nil {
434 return State{}, err
435 }
436 return stateFromCanceled(row), nil
437 }
438
439 func ClearBillingLock(ctx context.Context, deps Deps, orgID int64) (State, error) {
440 if err := validateDeps(deps); err != nil {
441 return State{}, err
442 }
443 if orgID == 0 {
444 return State{}, ErrOrgIDRequired
445 }
446 row, err := billingdb.New().ClearBillingLock(ctx, deps.Pool, orgID)
447 if err != nil {
448 return State{}, err
449 }
450 return stateFromClear(row), nil
451 }
452
453 func validateDeps(deps Deps) error {
454 if deps.Pool == nil {
455 return ErrPoolRequired
456 }
457 return nil
458 }
459
460 func validPlan(plan Plan) bool {
461 switch plan {
462 case PlanFree, PlanTeam, PlanEnterprise:
463 return true
464 default:
465 return false
466 }
467 }
468
469 func validStatus(status SubscriptionStatus) bool {
470 switch status {
471 case SubscriptionStatusNone,
472 SubscriptionStatusIncomplete,
473 SubscriptionStatusTrialing,
474 SubscriptionStatusActive,
475 SubscriptionStatusPastDue,
476 SubscriptionStatusCanceled,
477 SubscriptionStatusUnpaid,
478 SubscriptionStatusPaused:
479 return true
480 default:
481 return false
482 }
483 }
484
485 func validInvoiceStatus(status InvoiceStatus) bool {
486 switch status {
487 case InvoiceStatusDraft,
488 InvoiceStatusOpen,
489 InvoiceStatusPaid,
490 InvoiceStatusVoid,
491 InvoiceStatusUncollectible:
492 return true
493 default:
494 return false
495 }
496 }
497
498 func pgText(s string) pgtype.Text {
499 s = strings.TrimSpace(s)
500 return pgtype.Text{String: s, Valid: s != ""}
501 }
502
503 func pgTime(t time.Time) pgtype.Timestamptz {
504 return pgtype.Timestamptz{Time: t, Valid: !t.IsZero()}
505 }
506
507 func jsonObject(payload []byte) bool {
508 var v map[string]any
509 return json.Unmarshal(payload, &v) == nil && v != nil
510 }
511
512 func stateFromApply(row billingdb.ApplySubscriptionSnapshotRow) State {
513 return State(row)
514 }
515
516 func stateFromCanceled(row billingdb.MarkCanceledRow) State {
517 return State(row)
518 }
519
520 func stateFromPaymentSucceeded(row billingdb.MarkPaymentSucceededRow) State {
521 return State(row)
522 }
523
524 func stateFromClear(row billingdb.ClearBillingLockRow) State {
525 return State(row)
526 }
527