Go · 15052 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package billing owns local paid-organization state. It stores Stripe
4 // identifiers and derived subscription state, but it does not call
5 // Stripe directly; Stripe API details stay in the SP03 adapter layer.
6 package billing
7
8 import (
9 "context"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "strings"
14 "time"
15
16 "github.com/jackc/pgx/v5"
17 "github.com/jackc/pgx/v5/pgtype"
18 "github.com/jackc/pgx/v5/pgxpool"
19
20 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
21 )
22
// Deps carries the collaborators required by every billing operation.
//
// Every exported function in this file calls validateDeps first and
// returns ErrPoolRequired when Pool is nil.
type Deps struct {
	// Pool is the database pool the sqlc queries are executed against.
	Pool *pgxpool.Pool
}
26
27 type (
28 Plan = billingdb.OrgPlan
29 SubscriptionStatus = billingdb.BillingSubscriptionStatus
30 InvoiceStatus = billingdb.BillingInvoiceStatus
31 State = billingdb.OrgBillingState
32 )
33
34 const (
35 PlanFree = billingdb.OrgPlanFree
36 PlanTeam = billingdb.OrgPlanTeam
37 PlanEnterprise = billingdb.OrgPlanEnterprise
38
39 SubscriptionStatusNone = billingdb.BillingSubscriptionStatusNone
40 SubscriptionStatusIncomplete = billingdb.BillingSubscriptionStatusIncomplete
41 SubscriptionStatusTrialing = billingdb.BillingSubscriptionStatusTrialing
42 SubscriptionStatusActive = billingdb.BillingSubscriptionStatusActive
43 SubscriptionStatusPastDue = billingdb.BillingSubscriptionStatusPastDue
44 SubscriptionStatusCanceled = billingdb.BillingSubscriptionStatusCanceled
45 SubscriptionStatusUnpaid = billingdb.BillingSubscriptionStatusUnpaid
46 SubscriptionStatusPaused = billingdb.BillingSubscriptionStatusPaused
47
48 InvoiceStatusDraft = billingdb.BillingInvoiceStatusDraft
49 InvoiceStatusOpen = billingdb.BillingInvoiceStatusOpen
50 InvoiceStatusPaid = billingdb.BillingInvoiceStatusPaid
51 InvoiceStatusVoid = billingdb.BillingInvoiceStatusVoid
52 InvoiceStatusUncollectible = billingdb.BillingInvoiceStatusUncollectible
53 )
54
55 var (
56 ErrPoolRequired = errors.New("billing: pool is required")
57 ErrOrgIDRequired = errors.New("billing: org id is required")
58 ErrStripeCustomerID = errors.New("billing: stripe customer id is required")
59 ErrStripeSubscriptionID = errors.New("billing: stripe subscription id is required")
60 ErrStripeInvoiceID = errors.New("billing: stripe invoice id is required")
61 ErrInvalidPlan = errors.New("billing: invalid plan")
62 ErrInvalidStatus = errors.New("billing: invalid subscription status")
63 ErrInvalidInvoiceStatus = errors.New("billing: invalid invoice status")
64 ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative")
65 ErrWebhookEventID = errors.New("billing: webhook event id is required")
66 ErrWebhookEventType = errors.New("billing: webhook event type is required")
67 ErrWebhookPayload = errors.New("billing: webhook payload must be a JSON object")
68 )
69
// SubscriptionSnapshot is the local projection of a provider
// subscription event. Provider-specific conversion belongs in SP03.
//
// ApplySubscriptionSnapshot requires a non-zero OrgID and a Plan/Status
// drawn from the package's declared constants. Empty (after trimming)
// Stripe IDs are passed to the query as NULL text. Zero time.Time values
// are passed as NULL timestamps.
type SubscriptionSnapshot struct {
	OrgID int64
	Plan  Plan
	// Status must be one of the SubscriptionStatus* constants.
	Status                   SubscriptionStatus
	StripeSubscriptionID     string
	StripeSubscriptionItemID string
	CurrentPeriodStart       time.Time
	CurrentPeriodEnd         time.Time
	CancelAtPeriodEnd        bool
	TrialEnd                 time.Time
	CanceledAt               time.Time
	// LastWebhookEventID is whitespace-trimmed before being written.
	LastWebhookEventID string
}
85
// SeatSnapshot is the input to SyncSeatSnapshot: a point-in-time count of
// an organization's members and billable seats.
type SeatSnapshot struct {
	// OrgID must be non-zero.
	OrgID int64
	// StripeSubscriptionID is optional; empty (after trimming) is passed as NULL.
	StripeSubscriptionID string
	// ActiveMembers and BillableSeats must both be non-negative.
	ActiveMembers int
	BillableSeats int
	// Source labels where the counts came from; empty defaults to "local".
	Source string
}
93
// WebhookEvent is a raw provider webhook delivery to be recorded by
// RecordWebhookEvent.
type WebhookEvent struct {
	// ProviderEventID is required (non-empty after trimming).
	ProviderEventID string
	// EventType is required (non-empty after trimming).
	EventType string
	// APIVersion is optional and is trimmed before storage.
	APIVersion string
	// Payload must be a JSON object; null, arrays and scalars are rejected.
	Payload []byte
}
100
// InvoiceSnapshot is the local projection of a provider invoice, written by
// UpsertInvoice.
//
// OrgID, StripeInvoiceID and StripeCustomerID are required. Status must be
// one of the InvoiceStatus* constants. Zero time.Time values are passed to
// the query as NULL timestamps.
type InvoiceSnapshot struct {
	OrgID            int64
	StripeInvoiceID  string
	StripeCustomerID string
	// StripeSubscriptionID is optional; empty (after trimming) is passed as NULL.
	StripeSubscriptionID string
	Status               InvoiceStatus
	Number               string
	// Currency is trimmed and lower-cased before storage.
	Currency             string
	AmountDueCents       int64
	AmountPaidCents      int64
	AmountRemainingCents int64
	HostedInvoiceURL     string
	InvoicePDFURL        string
	PeriodStart          time.Time
	PeriodEnd            time.Time
	DueAt                time.Time
	PaidAt               time.Time
	VoidedAt             time.Time
}
120
// GetOrgBillingState loads the billing state for orgID.
//
// It returns ErrPoolRequired when deps has no pool and ErrOrgIDRequired when
// orgID is zero. Otherwise the query's result and error are passed through
// unchanged.
func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case orgID == 0:
		return State{}, ErrOrgIDRequired
	}
	queries := billingdb.New()
	return queries.GetOrgBillingState(ctx, deps.Pool, orgID)
}
130
// GetOrgBillingStateByStripeCustomer loads the billing state whose Stripe
// customer ID matches customerID (after trimming surrounding whitespace).
//
// It returns ErrPoolRequired when deps has no pool and ErrStripeCustomerID
// when the trimmed ID is empty.
func GetOrgBillingStateByStripeCustomer(ctx context.Context, deps Deps, customerID string) (State, error) {
	id := strings.TrimSpace(customerID)
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case len(id) == 0:
		return State{}, ErrStripeCustomerID
	}
	key := pgText(id)
	return billingdb.New().GetOrgBillingStateByStripeCustomer(ctx, deps.Pool, key)
}
141
// GetOrgBillingStateByStripeSubscription loads the billing state whose
// Stripe subscription ID matches subscriptionID (after trimming).
//
// It returns ErrPoolRequired when deps has no pool and
// ErrStripeSubscriptionID when the trimmed ID is empty.
func GetOrgBillingStateByStripeSubscription(ctx context.Context, deps Deps, subscriptionID string) (State, error) {
	id := strings.TrimSpace(subscriptionID)
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case len(id) == 0:
		return State{}, ErrStripeSubscriptionID
	}
	key := pgText(id)
	return billingdb.New().GetOrgBillingStateByStripeSubscription(ctx, deps.Pool, key)
}
152
// SetStripeCustomer associates the (trimmed) Stripe customer ID with orgID
// and returns the resulting billing state.
//
// Validation order: ErrPoolRequired, then ErrOrgIDRequired for a zero
// orgID, then ErrStripeCustomerID for an empty customer ID.
func SetStripeCustomer(ctx context.Context, deps Deps, orgID int64, customerID string) (State, error) {
	trimmed := strings.TrimSpace(customerID)
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case orgID == 0:
		return State{}, ErrOrgIDRequired
	case trimmed == "":
		return State{}, ErrStripeCustomerID
	}
	params := billingdb.SetStripeCustomerParams{
		OrgID:            orgID,
		StripeCustomerID: pgText(trimmed),
	}
	return billingdb.New().SetStripeCustomer(ctx, deps.Pool, params)
}
169
// ApplySubscriptionSnapshot writes snap into the organization's billing
// state and returns the updated state.
//
// It returns ErrPoolRequired, ErrOrgIDRequired, or a wrapped ErrInvalidPlan
// or ErrInvalidStatus (carrying the rejected value) when validation fails.
// Empty IDs and zero timestamps are passed to the query as NULLs via
// pgText and pgTime.
func ApplySubscriptionSnapshot(ctx context.Context, deps Deps, snap SubscriptionSnapshot) (State, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case snap.OrgID == 0:
		return State{}, ErrOrgIDRequired
	case !validPlan(snap.Plan):
		return State{}, fmt.Errorf("%w: %q", ErrInvalidPlan, snap.Plan)
	case !validStatus(snap.Status):
		return State{}, fmt.Errorf("%w: %q", ErrInvalidStatus, snap.Status)
	}

	params := billingdb.ApplySubscriptionSnapshotParams{
		OrgID:                    snap.OrgID,
		Plan:                     snap.Plan,
		SubscriptionStatus:       snap.Status,
		StripeSubscriptionID:     pgText(snap.StripeSubscriptionID),
		StripeSubscriptionItemID: pgText(snap.StripeSubscriptionItemID),
		CurrentPeriodStart:       pgTime(snap.CurrentPeriodStart),
		CurrentPeriodEnd:         pgTime(snap.CurrentPeriodEnd),
		CancelAtPeriodEnd:        snap.CancelAtPeriodEnd,
		TrialEnd:                 pgTime(snap.TrialEnd),
		CanceledAt:               pgTime(snap.CanceledAt),
		LastWebhookEventID:       strings.TrimSpace(snap.LastWebhookEventID),
	}
	row, err := billingdb.New().ApplySubscriptionSnapshot(ctx, deps.Pool, params)
	if err != nil {
		return State{}, err
	}
	return stateFromApply(row), nil
}
201
202 func RecordWebhookEvent(ctx context.Context, deps Deps, event WebhookEvent) (billingdb.BillingWebhookEvent, bool, error) {
203 if err := validateDeps(deps); err != nil {
204 return billingdb.BillingWebhookEvent{}, false, err
205 }
206 event.ProviderEventID = strings.TrimSpace(event.ProviderEventID)
207 event.EventType = strings.TrimSpace(event.EventType)
208 event.APIVersion = strings.TrimSpace(event.APIVersion)
209 if event.ProviderEventID == "" {
210 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventID
211 }
212 if event.EventType == "" {
213 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventType
214 }
215 if !jsonObject(event.Payload) {
216 return billingdb.BillingWebhookEvent{}, false, ErrWebhookPayload
217 }
218 row, err := billingdb.New().CreateWebhookEventReceipt(ctx, deps.Pool, billingdb.CreateWebhookEventReceiptParams{
219 ProviderEventID: event.ProviderEventID,
220 EventType: event.EventType,
221 ApiVersion: event.APIVersion,
222 Payload: event.Payload,
223 })
224 if err != nil {
225 if errors.Is(err, pgx.ErrNoRows) {
226 return billingdb.BillingWebhookEvent{}, false, nil
227 }
228 return billingdb.BillingWebhookEvent{}, false, err
229 }
230 return row, true, nil
231 }
232
// MarkWebhookEventProcessed marks the receipt identified by providerEventID
// (after trimming) as processed and returns the updated row.
//
// It returns ErrPoolRequired when deps has no pool and ErrWebhookEventID
// when the trimmed ID is empty.
func MarkWebhookEventProcessed(ctx context.Context, deps Deps, providerEventID string) (billingdb.BillingWebhookEvent, error) {
	eventID := strings.TrimSpace(providerEventID)
	switch err := validateDeps(deps); {
	case err != nil:
		return billingdb.BillingWebhookEvent{}, err
	case eventID == "":
		return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
	}
	return billingdb.New().MarkWebhookEventProcessed(ctx, deps.Pool, eventID)
}
243
244 func MarkWebhookEventFailed(ctx context.Context, deps Deps, providerEventID, processError string) (billingdb.BillingWebhookEvent, error) {
245 if err := validateDeps(deps); err != nil {
246 return billingdb.BillingWebhookEvent{}, err
247 }
248 providerEventID = strings.TrimSpace(providerEventID)
249 if providerEventID == "" {
250 return billingdb.BillingWebhookEvent{}, ErrWebhookEventID
251 }
252 processError = strings.TrimSpace(processError)
253 if len(processError) > 2000 {
254 processError = processError[:2000]
255 }
256 return billingdb.New().MarkWebhookEventFailed(ctx, deps.Pool, billingdb.MarkWebhookEventFailedParams{
257 ProviderEventID: providerEventID,
258 ProcessError: processError,
259 })
260 }
261
262 func UpsertInvoice(ctx context.Context, deps Deps, snap InvoiceSnapshot) (billingdb.BillingInvoice, error) {
263 if err := validateDeps(deps); err != nil {
264 return billingdb.BillingInvoice{}, err
265 }
266 if snap.OrgID == 0 {
267 return billingdb.BillingInvoice{}, ErrOrgIDRequired
268 }
269 snap.StripeInvoiceID = strings.TrimSpace(snap.StripeInvoiceID)
270 if snap.StripeInvoiceID == "" {
271 return billingdb.BillingInvoice{}, ErrStripeInvoiceID
272 }
273 snap.StripeCustomerID = strings.TrimSpace(snap.StripeCustomerID)
274 if snap.StripeCustomerID == "" {
275 return billingdb.BillingInvoice{}, ErrStripeCustomerID
276 }
277 if !validInvoiceStatus(snap.Status) {
278 return billingdb.BillingInvoice{}, fmt.Errorf("%w: %q", ErrInvalidInvoiceStatus, snap.Status)
279 }
280 row, err := billingdb.New().UpsertInvoice(ctx, deps.Pool, billingdb.UpsertInvoiceParams{
281 OrgID: snap.OrgID,
282 StripeInvoiceID: snap.StripeInvoiceID,
283 StripeCustomerID: snap.StripeCustomerID,
284 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
285 Status: snap.Status,
286 Number: strings.TrimSpace(snap.Number),
287 Currency: strings.ToLower(strings.TrimSpace(snap.Currency)),
288 AmountDueCents: snap.AmountDueCents,
289 AmountPaidCents: snap.AmountPaidCents,
290 AmountRemainingCents: snap.AmountRemainingCents,
291 HostedInvoiceUrl: strings.TrimSpace(snap.HostedInvoiceURL),
292 InvoicePdfUrl: strings.TrimSpace(snap.InvoicePDFURL),
293 PeriodStart: pgTime(snap.PeriodStart),
294 PeriodEnd: pgTime(snap.PeriodEnd),
295 DueAt: pgTime(snap.DueAt),
296 PaidAt: pgTime(snap.PaidAt),
297 VoidedAt: pgTime(snap.VoidedAt),
298 })
299 if err != nil {
300 return billingdb.BillingInvoice{}, err
301 }
302 return row, nil
303 }
304
305 func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billingdb.BillingSeatSnapshot, error) {
306 if err := validateDeps(deps); err != nil {
307 return billingdb.BillingSeatSnapshot{}, err
308 }
309 if snap.OrgID == 0 {
310 return billingdb.BillingSeatSnapshot{}, ErrOrgIDRequired
311 }
312 if snap.ActiveMembers < 0 || snap.BillableSeats < 0 {
313 return billingdb.BillingSeatSnapshot{}, ErrInvalidSeatCount
314 }
315 source := strings.TrimSpace(snap.Source)
316 if source == "" {
317 source = "local"
318 }
319 row, err := billingdb.New().CreateSeatSnapshot(ctx, deps.Pool, billingdb.CreateSeatSnapshotParams{
320 OrgID: snap.OrgID,
321 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
322 ActiveMembers: int32(snap.ActiveMembers),
323 BillableSeats: int32(snap.BillableSeats),
324 Source: source,
325 })
326 if err != nil {
327 return billingdb.BillingSeatSnapshot{}, err
328 }
329 return billingdb.BillingSeatSnapshot(row), nil
330 }
331
// CountBillableOrgMembers returns the result of the CountBillableOrgMembers
// query for orgID, converted to int.
//
// It returns ErrPoolRequired or ErrOrgIDRequired on invalid input, and 0
// alongside any query error.
func CountBillableOrgMembers(ctx context.Context, deps Deps, orgID int64) (int, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return 0, err
	case orgID == 0:
		return 0, ErrOrgIDRequired
	}
	count, err := billingdb.New().CountBillableOrgMembers(ctx, deps.Pool, orgID)
	if err != nil {
		return 0, err
	}
	return int(count), nil
}
345
// MarkPastDue runs the MarkPastDue query for orgID and returns the
// resulting billing state.
//
// A zero graceUntil is passed as a NULL timestamp (via pgTime).
// lastWebhookEventID is whitespace-trimmed. It returns ErrPoolRequired or
// ErrOrgIDRequired on invalid input.
func MarkPastDue(ctx context.Context, deps Deps, orgID int64, graceUntil time.Time, lastWebhookEventID string) (State, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case orgID == 0:
		return State{}, ErrOrgIDRequired
	}
	params := billingdb.MarkPastDueParams{
		OrgID:              orgID,
		GraceUntil:         pgTime(graceUntil),
		LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
	}
	return billingdb.New().MarkPastDue(ctx, deps.Pool, params)
}
359
// MarkCanceled runs the MarkCanceled query for orgID, recording the trimmed
// lastWebhookEventID, and returns the resulting billing state.
//
// It returns ErrPoolRequired or ErrOrgIDRequired on invalid input, and a
// zero State alongside any query error.
func MarkCanceled(ctx context.Context, deps Deps, orgID int64, lastWebhookEventID string) (State, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case orgID == 0:
		return State{}, ErrOrgIDRequired
	}
	params := billingdb.MarkCanceledParams{
		OrgID:              orgID,
		LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
	}
	canceled, err := billingdb.New().MarkCanceled(ctx, deps.Pool, params)
	if err != nil {
		return State{}, err
	}
	return stateFromCanceled(canceled), nil
}
376
// ClearBillingLock runs the ClearBillingLock query for orgID and returns
// the resulting billing state.
//
// It returns ErrPoolRequired or ErrOrgIDRequired on invalid input, and a
// zero State alongside any query error.
func ClearBillingLock(ctx context.Context, deps Deps, orgID int64) (State, error) {
	switch err := validateDeps(deps); {
	case err != nil:
		return State{}, err
	case orgID == 0:
		return State{}, ErrOrgIDRequired
	}
	cleared, err := billingdb.New().ClearBillingLock(ctx, deps.Pool, orgID)
	if err != nil {
		return State{}, err
	}
	return stateFromClear(cleared), nil
}
390
391 func validateDeps(deps Deps) error {
392 if deps.Pool == nil {
393 return ErrPoolRequired
394 }
395 return nil
396 }
397
398 func validPlan(plan Plan) bool {
399 switch plan {
400 case PlanFree, PlanTeam, PlanEnterprise:
401 return true
402 default:
403 return false
404 }
405 }
406
407 func validStatus(status SubscriptionStatus) bool {
408 switch status {
409 case SubscriptionStatusNone,
410 SubscriptionStatusIncomplete,
411 SubscriptionStatusTrialing,
412 SubscriptionStatusActive,
413 SubscriptionStatusPastDue,
414 SubscriptionStatusCanceled,
415 SubscriptionStatusUnpaid,
416 SubscriptionStatusPaused:
417 return true
418 default:
419 return false
420 }
421 }
422
423 func validInvoiceStatus(status InvoiceStatus) bool {
424 switch status {
425 case InvoiceStatusDraft,
426 InvoiceStatusOpen,
427 InvoiceStatusPaid,
428 InvoiceStatusVoid,
429 InvoiceStatusUncollectible:
430 return true
431 default:
432 return false
433 }
434 }
435
// pgText converts s to a nullable text parameter. Surrounding whitespace is
// trimmed, and an empty result becomes the NULL (invalid) value.
func pgText(s string) pgtype.Text {
	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return pgtype.Text{}
	}
	return pgtype.Text{String: trimmed, Valid: true}
}
440
// pgTime converts t to a nullable timestamptz parameter. The zero
// time.Time becomes the NULL (invalid) value; t is carried in either case.
func pgTime(t time.Time) pgtype.Timestamptz {
	if t.IsZero() {
		return pgtype.Timestamptz{Time: t}
	}
	return pgtype.Timestamptz{Time: t, Valid: true}
}
444
445 func jsonObject(payload []byte) bool {
446 var v map[string]any
447 return json.Unmarshal(payload, &v) == nil && v != nil
448 }
449
450 func stateFromApply(row billingdb.ApplySubscriptionSnapshotRow) State {
451 return State(row)
452 }
453
454 func stateFromCanceled(row billingdb.MarkCanceledRow) State {
455 return State(row)
456 }
457
458 func stateFromClear(row billingdb.ClearBillingLockRow) State {
459 return State(row)
460 }
461