Go · 9574 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package billing owns local paid-organization state. It stores Stripe
4 // identifiers and derived subscription state, but it does not call
5 // Stripe directly; webhook/API integration lands in SP03.
6 package billing
7
8 import (
9 "context"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "strings"
14 "time"
15
16 "github.com/jackc/pgx/v5"
17 "github.com/jackc/pgx/v5/pgtype"
18 "github.com/jackc/pgx/v5/pgxpool"
19
20 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
21 )
22
23 type Deps struct {
24 Pool *pgxpool.Pool
25 }
26
27 type (
28 Plan = billingdb.OrgPlan
29 SubscriptionStatus = billingdb.BillingSubscriptionStatus
30 State = billingdb.OrgBillingState
31 )
32
33 const (
34 PlanFree = billingdb.OrgPlanFree
35 PlanTeam = billingdb.OrgPlanTeam
36 PlanEnterprise = billingdb.OrgPlanEnterprise
37
38 SubscriptionStatusNone = billingdb.BillingSubscriptionStatusNone
39 SubscriptionStatusIncomplete = billingdb.BillingSubscriptionStatusIncomplete
40 SubscriptionStatusTrialing = billingdb.BillingSubscriptionStatusTrialing
41 SubscriptionStatusActive = billingdb.BillingSubscriptionStatusActive
42 SubscriptionStatusPastDue = billingdb.BillingSubscriptionStatusPastDue
43 SubscriptionStatusCanceled = billingdb.BillingSubscriptionStatusCanceled
44 SubscriptionStatusUnpaid = billingdb.BillingSubscriptionStatusUnpaid
45 SubscriptionStatusPaused = billingdb.BillingSubscriptionStatusPaused
46 )
47
48 var (
49 ErrPoolRequired = errors.New("billing: pool is required")
50 ErrOrgIDRequired = errors.New("billing: org id is required")
51 ErrStripeCustomerID = errors.New("billing: stripe customer id is required")
52 ErrInvalidPlan = errors.New("billing: invalid plan")
53 ErrInvalidStatus = errors.New("billing: invalid subscription status")
54 ErrInvalidSeatCount = errors.New("billing: seat counts cannot be negative")
55 ErrWebhookEventID = errors.New("billing: webhook event id is required")
56 ErrWebhookEventType = errors.New("billing: webhook event type is required")
57 ErrWebhookPayload = errors.New("billing: webhook payload must be a JSON object")
58 )
59
60 // SubscriptionSnapshot is the local projection of a provider
61 // subscription event. Provider-specific conversion belongs in SP03.
62 type SubscriptionSnapshot struct {
63 OrgID int64
64 Plan Plan
65 Status SubscriptionStatus
66 StripeSubscriptionID string
67 StripeSubscriptionItemID string
68 CurrentPeriodStart time.Time
69 CurrentPeriodEnd time.Time
70 CancelAtPeriodEnd bool
71 TrialEnd time.Time
72 CanceledAt time.Time
73 LastWebhookEventID string
74 }
75
76 type SeatSnapshot struct {
77 OrgID int64
78 StripeSubscriptionID string
79 ActiveMembers int
80 BillableSeats int
81 Source string
82 }
83
84 type WebhookEvent struct {
85 ProviderEventID string
86 EventType string
87 APIVersion string
88 Payload []byte
89 }
90
91 func GetOrgBillingState(ctx context.Context, deps Deps, orgID int64) (State, error) {
92 if err := validateDeps(deps); err != nil {
93 return State{}, err
94 }
95 if orgID == 0 {
96 return State{}, ErrOrgIDRequired
97 }
98 return billingdb.New().GetOrgBillingState(ctx, deps.Pool, orgID)
99 }
100
101 func SetStripeCustomer(ctx context.Context, deps Deps, orgID int64, customerID string) (State, error) {
102 if err := validateDeps(deps); err != nil {
103 return State{}, err
104 }
105 if orgID == 0 {
106 return State{}, ErrOrgIDRequired
107 }
108 customerID = strings.TrimSpace(customerID)
109 if customerID == "" {
110 return State{}, ErrStripeCustomerID
111 }
112 return billingdb.New().SetStripeCustomer(ctx, deps.Pool, billingdb.SetStripeCustomerParams{
113 OrgID: orgID,
114 StripeCustomerID: pgText(customerID),
115 })
116 }
117
118 func ApplySubscriptionSnapshot(ctx context.Context, deps Deps, snap SubscriptionSnapshot) (State, error) {
119 if err := validateDeps(deps); err != nil {
120 return State{}, err
121 }
122 if snap.OrgID == 0 {
123 return State{}, ErrOrgIDRequired
124 }
125 if !validPlan(snap.Plan) {
126 return State{}, fmt.Errorf("%w: %q", ErrInvalidPlan, snap.Plan)
127 }
128 if !validStatus(snap.Status) {
129 return State{}, fmt.Errorf("%w: %q", ErrInvalidStatus, snap.Status)
130 }
131 row, err := billingdb.New().ApplySubscriptionSnapshot(ctx, deps.Pool, billingdb.ApplySubscriptionSnapshotParams{
132 OrgID: snap.OrgID,
133 Plan: snap.Plan,
134 SubscriptionStatus: snap.Status,
135 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
136 StripeSubscriptionItemID: pgText(snap.StripeSubscriptionItemID),
137 CurrentPeriodStart: pgTime(snap.CurrentPeriodStart),
138 CurrentPeriodEnd: pgTime(snap.CurrentPeriodEnd),
139 CancelAtPeriodEnd: snap.CancelAtPeriodEnd,
140 TrialEnd: pgTime(snap.TrialEnd),
141 CanceledAt: pgTime(snap.CanceledAt),
142 LastWebhookEventID: strings.TrimSpace(snap.LastWebhookEventID),
143 })
144 if err != nil {
145 return State{}, err
146 }
147 return stateFromApply(row), nil
148 }
149
150 func RecordWebhookEvent(ctx context.Context, deps Deps, event WebhookEvent) (billingdb.BillingWebhookEvent, bool, error) {
151 if err := validateDeps(deps); err != nil {
152 return billingdb.BillingWebhookEvent{}, false, err
153 }
154 event.ProviderEventID = strings.TrimSpace(event.ProviderEventID)
155 event.EventType = strings.TrimSpace(event.EventType)
156 event.APIVersion = strings.TrimSpace(event.APIVersion)
157 if event.ProviderEventID == "" {
158 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventID
159 }
160 if event.EventType == "" {
161 return billingdb.BillingWebhookEvent{}, false, ErrWebhookEventType
162 }
163 if !jsonObject(event.Payload) {
164 return billingdb.BillingWebhookEvent{}, false, ErrWebhookPayload
165 }
166 row, err := billingdb.New().CreateWebhookEventReceipt(ctx, deps.Pool, billingdb.CreateWebhookEventReceiptParams{
167 ProviderEventID: event.ProviderEventID,
168 EventType: event.EventType,
169 ApiVersion: event.APIVersion,
170 Payload: event.Payload,
171 })
172 if err != nil {
173 if errors.Is(err, pgx.ErrNoRows) {
174 return billingdb.BillingWebhookEvent{}, false, nil
175 }
176 return billingdb.BillingWebhookEvent{}, false, err
177 }
178 return row, true, nil
179 }
180
181 func SyncSeatSnapshot(ctx context.Context, deps Deps, snap SeatSnapshot) (billingdb.BillingSeatSnapshot, error) {
182 if err := validateDeps(deps); err != nil {
183 return billingdb.BillingSeatSnapshot{}, err
184 }
185 if snap.OrgID == 0 {
186 return billingdb.BillingSeatSnapshot{}, ErrOrgIDRequired
187 }
188 if snap.ActiveMembers < 0 || snap.BillableSeats < 0 {
189 return billingdb.BillingSeatSnapshot{}, ErrInvalidSeatCount
190 }
191 source := strings.TrimSpace(snap.Source)
192 if source == "" {
193 source = "local"
194 }
195 row, err := billingdb.New().CreateSeatSnapshot(ctx, deps.Pool, billingdb.CreateSeatSnapshotParams{
196 OrgID: snap.OrgID,
197 StripeSubscriptionID: pgText(snap.StripeSubscriptionID),
198 ActiveMembers: int32(snap.ActiveMembers),
199 BillableSeats: int32(snap.BillableSeats),
200 Source: source,
201 })
202 if err != nil {
203 return billingdb.BillingSeatSnapshot{}, err
204 }
205 return billingdb.BillingSeatSnapshot(row), nil
206 }
207
208 func MarkPastDue(ctx context.Context, deps Deps, orgID int64, graceUntil time.Time, lastWebhookEventID string) (State, error) {
209 if err := validateDeps(deps); err != nil {
210 return State{}, err
211 }
212 if orgID == 0 {
213 return State{}, ErrOrgIDRequired
214 }
215 return billingdb.New().MarkPastDue(ctx, deps.Pool, billingdb.MarkPastDueParams{
216 OrgID: orgID,
217 GraceUntil: pgTime(graceUntil),
218 LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
219 })
220 }
221
222 func MarkCanceled(ctx context.Context, deps Deps, orgID int64, lastWebhookEventID string) (State, error) {
223 if err := validateDeps(deps); err != nil {
224 return State{}, err
225 }
226 if orgID == 0 {
227 return State{}, ErrOrgIDRequired
228 }
229 row, err := billingdb.New().MarkCanceled(ctx, deps.Pool, billingdb.MarkCanceledParams{
230 OrgID: orgID,
231 LastWebhookEventID: strings.TrimSpace(lastWebhookEventID),
232 })
233 if err != nil {
234 return State{}, err
235 }
236 return stateFromCanceled(row), nil
237 }
238
239 func ClearBillingLock(ctx context.Context, deps Deps, orgID int64) (State, error) {
240 if err := validateDeps(deps); err != nil {
241 return State{}, err
242 }
243 if orgID == 0 {
244 return State{}, ErrOrgIDRequired
245 }
246 row, err := billingdb.New().ClearBillingLock(ctx, deps.Pool, orgID)
247 if err != nil {
248 return State{}, err
249 }
250 return stateFromClear(row), nil
251 }
252
253 func validateDeps(deps Deps) error {
254 if deps.Pool == nil {
255 return ErrPoolRequired
256 }
257 return nil
258 }
259
260 func validPlan(plan Plan) bool {
261 switch plan {
262 case PlanFree, PlanTeam, PlanEnterprise:
263 return true
264 default:
265 return false
266 }
267 }
268
269 func validStatus(status SubscriptionStatus) bool {
270 switch status {
271 case SubscriptionStatusNone,
272 SubscriptionStatusIncomplete,
273 SubscriptionStatusTrialing,
274 SubscriptionStatusActive,
275 SubscriptionStatusPastDue,
276 SubscriptionStatusCanceled,
277 SubscriptionStatusUnpaid,
278 SubscriptionStatusPaused:
279 return true
280 default:
281 return false
282 }
283 }
284
285 func pgText(s string) pgtype.Text {
286 s = strings.TrimSpace(s)
287 return pgtype.Text{String: s, Valid: s != ""}
288 }
289
290 func pgTime(t time.Time) pgtype.Timestamptz {
291 return pgtype.Timestamptz{Time: t, Valid: !t.IsZero()}
292 }
293
294 func jsonObject(payload []byte) bool {
295 var v map[string]any
296 return json.Unmarshal(payload, &v) == nil && v != nil
297 }
298
299 func stateFromApply(row billingdb.ApplySubscriptionSnapshotRow) State {
300 return State(row)
301 }
302
303 func stateFromCanceled(row billingdb.MarkCanceledRow) State {
304 return State(row)
305 }
306
307 func stateFromClear(row billingdb.ClearBillingLockRow) State {
308 return State(row)
309 }
310