@@ -1,5 +1,11 @@ |
| 1 | 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | 2 | |
| 3 | +// PRO04 note: this file is now subject-agnostic — it routes Stripe |
| 4 | +// webhook events to either org or user billing state based on the |
| 5 | +// resolved Principal. The file still lives under `handlers/orgs/` |
| 6 | +// for wiring continuity; a follow-up sprint moves it to |
| 7 | +// `handlers/billing/` once the SP-only callers are gone. |
| 8 | + |
| 3 | 9 | package orgs |
| 4 | 10 | |
| 5 | 11 | import ( |
@@ -13,7 +19,6 @@ import ( |
| 13 | 19 | "strings" |
| 14 | 20 | "time" |
| 15 | 21 | |
| 16 | | - "github.com/jackc/pgx/v5" |
| 17 | 22 | stripeapi "github.com/stripe/stripe-go/v85" |
| 18 | 23 | |
| 19 | 24 | orgbilling "github.com/tenseleyFlow/shithub/internal/billing" |
@@ -89,35 +94,70 @@ func (h *Handlers) applyStripeCheckoutCompleted(ctx context.Context, event strip |
| 89 | 94 | if err := unmarshalStripeEventObject(event, &session); err != nil { |
| 90 | 95 | return err |
| 91 | 96 | } |
| 92 | | - orgID := stripeOrgIDFromMetadata(session.Metadata) |
| 93 | | - if orgID == 0 { |
| 94 | | - if id, err := strconv.ParseInt(strings.TrimSpace(session.ClientReferenceID), 10, 64); err == nil && id > 0 { |
| 95 | | - orgID = id |
| 96 | | - } |
| 97 | | - } |
| 98 | | - if orgID == 0 { |
| 99 | | - return errors.New("stripe checkout.session.completed missing shithub org metadata") |
| 100 | | - } |
| 101 | 97 | customerID := stripeCustomerID(session.Customer) |
| 102 | 98 | if customerID == "" { |
| 103 | 99 | return errors.New("stripe checkout.session.completed missing customer") |
| 104 | 100 | } |
| 105 | | - _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, customerID) |
| 101 | + principal, err := h.resolvePrincipalFromCheckout(ctx, &session, customerID) |
| 102 | + if err != nil { |
| 103 | + return err |
| 104 | + } |
| 105 | + _, err = orgbilling.SetStripeCustomerForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, customerID) |
| 106 | 106 | return err |
| 107 | 107 | } |
| 108 | 108 | |
| 109 | +// resolvePrincipalFromCheckout walks the resolution chain for a |
| 110 | +// checkout.session.completed event. Order matches the spec: |
| 111 | +// 1. metadata.shithub_subject_kind + shithub_subject_id (PRO04 path) |
| 112 | +// 2. metadata.shithub_org_id (legacy SP03 path) |
| 113 | +// 3. client_reference_id parsed as int (legacy SP03 path) |
| 114 | +// 4. customer-id lookup against both tables |
| 115 | +// |
| 116 | +// Any path that yields a Principal returns immediately; the |
| 117 | +// fall-through error covers events that can't be matched at all. |
| 118 | +func (h *Handlers) resolvePrincipalFromCheckout(ctx context.Context, session *stripeapi.CheckoutSession, customerID string) (orgbilling.Principal, error) { |
| 119 | + if p, ok := stripePrincipalFromMetadata(session.Metadata); ok { |
| 120 | + return p, nil |
| 121 | + } |
| 122 | + if orgID := stripeOrgIDFromMetadata(session.Metadata); orgID != 0 { |
| 123 | + return orgbilling.PrincipalForOrg(orgID), nil |
| 124 | + } |
| 125 | + if id, err := strconv.ParseInt(strings.TrimSpace(session.ClientReferenceID), 10, 64); err == nil && id > 0 { |
| 126 | + // Legacy client_reference_id is org-only by convention. |
| 127 | + return orgbilling.PrincipalForOrg(id), nil |
| 128 | + } |
| 129 | + if customerID != "" { |
| 130 | + state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID) |
| 131 | + if err == nil { |
| 132 | + return state.Principal, nil |
| 133 | + } |
| 134 | + if !errors.Is(err, orgbilling.ErrPrincipalNotFound) { |
| 135 | + return orgbilling.Principal{}, err |
| 136 | + } |
| 137 | + } |
| 138 | + return orgbilling.Principal{}, errors.New("stripe checkout.session.completed missing shithub subject metadata") |
| 139 | +} |
| 140 | + |
| 109 | 141 | func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event stripeapi.Event) error { |
| 110 | 142 | var sub stripeapi.Subscription |
| 111 | 143 | if err := unmarshalStripeEventObject(event, &sub); err != nil { |
| 112 | 144 | return err |
| 113 | 145 | } |
| 114 | | - orgID, err := h.resolveOrgIDFromSubscription(ctx, &sub) |
| 146 | + principal, err := h.resolvePrincipalFromSubscription(ctx, &sub) |
| 115 | 147 | if err != nil { |
| 116 | 148 | return err |
| 117 | 149 | } |
| 150 | + // Cross-kind price-id check: if the subscription's first item |
| 151 | + // price doesn't match the expected price for the resolved kind, |
| 152 | + // refuse to apply. A Pro price on an org subject (or Team on |
| 153 | + // user) means metadata was misconfigured in the Stripe Dashboard; |
| 154 | + // silently applying would corrupt the wrong table. |
| 155 | + if err := h.guardPriceKindMatch(principal.Kind, &sub); err != nil { |
| 156 | + return err |
| 157 | + } |
| 118 | 158 | customerID := stripeCustomerID(sub.Customer) |
| 119 | 159 | if customerID != "" { |
| 120 | | - if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, customerID); err != nil { |
| 160 | + if _, err := orgbilling.SetStripeCustomerForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, customerID); err != nil { |
| 121 | 161 | return err |
| 122 | 162 | } |
| 123 | 163 | } |
@@ -126,14 +166,13 @@ func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event strip |
| 126 | 166 | return err |
| 127 | 167 | } |
| 128 | 168 | if status == orgbilling.SubscriptionStatusCanceled || string(event.Type) == "customer.subscription.deleted" { |
| 129 | | - _, err := orgbilling.MarkCanceled(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, event.ID) |
| 169 | + _, err := orgbilling.MarkCanceledForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, event.ID) |
| 130 | 170 | return err |
| 131 | 171 | } |
| 132 | 172 | itemID := stripeSubscriptionItemID(sub.Items) |
| 133 | 173 | periodStart, periodEnd := stripeSubscriptionPeriod(sub.Items) |
| 134 | | - _, err = orgbilling.ApplySubscriptionSnapshot(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.SubscriptionSnapshot{ |
| 135 | | - OrgID: orgID, |
| 136 | | - Plan: orgbilling.PlanTeam, |
| 174 | + _, err = orgbilling.ApplySubscriptionSnapshotForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.PrincipalSubscriptionSnapshot{ |
| 175 | + Principal: principal, |
| 137 | 176 | Status: status, |
| 138 | 177 | StripeSubscriptionID: strings.TrimSpace(sub.ID), |
| 139 | 178 | StripeSubscriptionItemID: itemID, |
@@ -147,12 +186,79 @@ func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event strip |
| 147 | 186 | return err |
| 148 | 187 | } |
| 149 | 188 | |
| 189 | +// resolvePrincipalFromSubscription walks the same chain as the |
| 190 | +// checkout resolver but starts from a subscription object. |
| 191 | +func (h *Handlers) resolvePrincipalFromSubscription(ctx context.Context, sub *stripeapi.Subscription) (orgbilling.Principal, error) { |
| 192 | + if p, ok := stripePrincipalFromMetadata(sub.Metadata); ok { |
| 193 | + return p, nil |
| 194 | + } |
| 195 | + if orgID := stripeOrgIDFromMetadata(sub.Metadata); orgID != 0 { |
| 196 | + return orgbilling.PrincipalForOrg(orgID), nil |
| 197 | + } |
| 198 | + if customerID := stripeCustomerID(sub.Customer); customerID != "" { |
| 199 | + state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID) |
| 200 | + if err == nil { |
| 201 | + return state.Principal, nil |
| 202 | + } |
| 203 | + if !errors.Is(err, orgbilling.ErrPrincipalNotFound) { |
| 204 | + return orgbilling.Principal{}, err |
| 205 | + } |
| 206 | + } |
| 207 | + if subID := strings.TrimSpace(sub.ID); subID != "" { |
| 208 | + state, err := orgbilling.ResolvePrincipalByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID) |
| 209 | + if err == nil { |
| 210 | + return state.Principal, nil |
| 211 | + } |
| 212 | + if !errors.Is(err, orgbilling.ErrPrincipalNotFound) { |
| 213 | + return orgbilling.Principal{}, err |
| 214 | + } |
| 215 | + } |
| 216 | + return orgbilling.Principal{}, errors.New("stripe subscription does not map to a shithub subject") |
| 217 | +} |
| 218 | + |
| 219 | +// guardPriceKindMatch refuses to apply a subscription when the |
| 220 | +// price-id on its first line item doesn't match the expected price |
| 221 | +// for the resolved subject kind. Catches dashboard-side |
| 222 | +// misconfiguration before it writes the wrong table. |
| 223 | +// |
| 224 | +// The check requires the handler to know which price-id is Pro and |
| 225 | +// which is Team — that wiring lands via BillingPriceIDs(); a |
| 226 | +// non-configured client (Pro disabled) skips the check rather than |
| 227 | +// rejecting org events. PRO-disabled instances never see Pro |
| 228 | +// events, so the org path is unaffected. |
| 229 | +func (h *Handlers) guardPriceKindMatch(kind orgbilling.SubjectKind, sub *stripeapi.Subscription) error { |
| 230 | + if sub == nil || sub.Items == nil || len(sub.Items.Data) == 0 || sub.Items.Data[0] == nil || sub.Items.Data[0].Price == nil { |
| 231 | + // No price on the event — nothing to validate. Subsequent |
| 232 | + // apply logic surfaces the missing-data error if needed. |
| 233 | + return nil |
| 234 | + } |
| 235 | + priceID := strings.TrimSpace(sub.Items.Data[0].Price.ID) |
| 236 | + teamPrice, proPrice := h.d.BillingPriceIDs() |
| 237 | + switch kind { |
| 238 | + case orgbilling.SubjectKindOrg: |
| 239 | + if teamPrice != "" && priceID != "" && priceID != teamPrice { |
| 240 | + if priceID == proPrice { |
| 241 | + return fmt.Errorf("stripe subscription: Pro price %q applied to org subject — metadata likely misconfigured", priceID) |
| 242 | + } |
| 243 | + return fmt.Errorf("stripe subscription: price %q does not match expected team price %q for org subject", priceID, teamPrice) |
| 244 | + } |
| 245 | + case orgbilling.SubjectKindUser: |
| 246 | + if proPrice != "" && priceID != "" && priceID != proPrice { |
| 247 | + if priceID == teamPrice { |
| 248 | + return fmt.Errorf("stripe subscription: Team price %q applied to user subject — metadata likely misconfigured", priceID) |
| 249 | + } |
| 250 | + return fmt.Errorf("stripe subscription: price %q does not match expected pro price %q for user subject", priceID, proPrice) |
| 251 | + } |
| 252 | + } |
| 253 | + return nil |
| 254 | +} |
| 255 | + |
| 150 | 256 | func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi.Event) error { |
| 151 | 257 | var inv stripeapi.Invoice |
| 152 | 258 | if err := unmarshalStripeEventObject(event, &inv); err != nil { |
| 153 | 259 | return err |
| 154 | 260 | } |
| 155 | | - orgID, state, err := h.resolveOrgStateFromInvoice(ctx, &inv) |
| 261 | + principalState, err := h.resolvePrincipalStateFromInvoice(ctx, &inv) |
| 156 | 262 | if err != nil { |
| 157 | 263 | return err |
| 158 | 264 | } |
@@ -160,8 +266,7 @@ func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi. |
| 160 | 266 | if err != nil { |
| 161 | 267 | return err |
| 162 | 268 | } |
| 163 | | - if _, err := orgbilling.UpsertInvoice(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.InvoiceSnapshot{ |
| 164 | | - OrgID: orgID, |
| 269 | + if _, err := orgbilling.UpsertInvoiceForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, orgbilling.InvoiceSnapshot{ |
| 165 | 270 | StripeInvoiceID: strings.TrimSpace(inv.ID), |
| 166 | 271 | StripeCustomerID: stripeCustomerID(inv.Customer), |
| 167 | 272 | StripeSubscriptionID: stripeInvoiceSubscriptionID(&inv), |
@@ -184,64 +289,67 @@ func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi. |
| 184 | 289 | switch string(event.Type) { |
| 185 | 290 | case "invoice.payment_failed": |
| 186 | 291 | graceUntil := time.Now().UTC().Add(h.d.BillingGracePeriod) |
| 187 | | - _, err := orgbilling.MarkPastDue(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, graceUntil, event.ID) |
| 292 | + _, err := orgbilling.MarkPastDueForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, graceUntil, event.ID) |
| 188 | 293 | return err |
| 189 | 294 | case "invoice.payment_succeeded": |
| 190 | | - if state.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled { |
| 191 | | - _, err := orgbilling.MarkPaymentSucceeded(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, event.ID) |
| 295 | + if principalState.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled { |
| 296 | + _, err := orgbilling.MarkPaymentSucceededForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, event.ID) |
| 192 | 297 | return err |
| 193 | 298 | } |
| 194 | 299 | } |
| 195 | 300 | return nil |
| 196 | 301 | } |
| 197 | 302 | |
| 198 | | -func (h *Handlers) resolveOrgIDFromSubscription(ctx context.Context, sub *stripeapi.Subscription) (int64, error) { |
| 199 | | - if orgID := stripeOrgIDFromMetadata(sub.Metadata); orgID != 0 { |
| 200 | | - return orgID, nil |
| 201 | | - } |
| 202 | | - if customerID := stripeCustomerID(sub.Customer); customerID != "" { |
| 203 | | - state, err := orgbilling.GetOrgBillingStateByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID) |
| 204 | | - if err == nil { |
| 205 | | - return state.OrgID, nil |
| 206 | | - } |
| 207 | | - if !errors.Is(err, pgx.ErrNoRows) { |
| 208 | | - return 0, err |
| 209 | | - } |
| 210 | | - } |
| 211 | | - if subID := strings.TrimSpace(sub.ID); subID != "" { |
| 212 | | - state, err := orgbilling.GetOrgBillingStateByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID) |
| 213 | | - if err == nil { |
| 214 | | - return state.OrgID, nil |
| 215 | | - } |
| 216 | | - if !errors.Is(err, pgx.ErrNoRows) { |
| 217 | | - return 0, err |
| 218 | | - } |
| 219 | | - } |
| 220 | | - return 0, errors.New("stripe subscription does not map to a shithub organization") |
| 221 | | -} |
| 222 | | - |
| 223 | | -func (h *Handlers) resolveOrgStateFromInvoice(ctx context.Context, inv *stripeapi.Invoice) (int64, orgbilling.State, error) { |
| 303 | +// resolvePrincipalStateFromInvoice resolves Principal AND fetches |
| 304 | +// the current billing state in one shot — the apply branch needs |
| 305 | +// the SubscriptionStatus to decide whether to flip payment- |
| 306 | +// succeeded transitions. Mirrors the legacy |
| 307 | +// resolveOrgStateFromInvoice but returns a kind-tagged Principal. |
| 308 | +func (h *Handlers) resolvePrincipalStateFromInvoice(ctx context.Context, inv *stripeapi.Invoice) (orgbilling.PrincipalState, error) { |
| 224 | 309 | if customerID := stripeCustomerID(inv.Customer); customerID != "" { |
| 225 | | - state, err := orgbilling.GetOrgBillingStateByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID) |
| 310 | + state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID) |
| 226 | 311 | if err == nil { |
| 227 | | - return state.OrgID, state, nil |
| 312 | + return state, nil |
| 228 | 313 | } |
| 229 | | - if !errors.Is(err, pgx.ErrNoRows) { |
| 230 | | - return 0, orgbilling.State{}, err |
| 314 | + if !errors.Is(err, orgbilling.ErrPrincipalNotFound) { |
| 315 | + return orgbilling.PrincipalState{}, err |
| 231 | 316 | } |
| 232 | 317 | } |
| 233 | 318 | if subID := stripeInvoiceSubscriptionID(inv); subID != "" { |
| 234 | | - state, err := orgbilling.GetOrgBillingStateByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID) |
| 319 | + state, err := orgbilling.ResolvePrincipalByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID) |
| 235 | 320 | if err == nil { |
| 236 | | - return state.OrgID, state, nil |
| 321 | + return state, nil |
| 237 | 322 | } |
| 238 | | - if !errors.Is(err, pgx.ErrNoRows) { |
| 239 | | - return 0, orgbilling.State{}, err |
| 323 | + if !errors.Is(err, orgbilling.ErrPrincipalNotFound) { |
| 324 | + return orgbilling.PrincipalState{}, err |
| 240 | 325 | } |
| 241 | 326 | } |
| 242 | | - return 0, orgbilling.State{}, errors.New("stripe invoice does not map to a shithub organization") |
| 327 | + return orgbilling.PrincipalState{}, errors.New("stripe invoice does not map to a shithub subject") |
| 328 | +} |
| 329 | + |
| 330 | +// stripePrincipalFromMetadata reads the PRO04 subject metadata |
| 331 | +// keys. Returns ok=false when either key is missing or malformed — |
| 332 | +// the caller falls through to the legacy resolution chain. |
| 333 | +func stripePrincipalFromMetadata(metadata map[string]string) (orgbilling.Principal, bool) { |
| 334 | + if len(metadata) == 0 { |
| 335 | + return orgbilling.Principal{}, false |
| 336 | + } |
| 337 | + kind := orgbilling.SubjectKind(strings.TrimSpace(metadata[stripebilling.MetadataSubjectKind])) |
| 338 | + if !kind.Valid() { |
| 339 | + return orgbilling.Principal{}, false |
| 340 | + } |
| 341 | + rawID := strings.TrimSpace(metadata[stripebilling.MetadataSubjectID]) |
| 342 | + id, err := strconv.ParseInt(rawID, 10, 64) |
| 343 | + if err != nil || id <= 0 { |
| 344 | + return orgbilling.Principal{}, false |
| 345 | + } |
| 346 | + return orgbilling.Principal{Kind: kind, ID: id}, true |
| 243 | 347 | } |
| 244 | 348 | |
| 349 | +// stripeOrgIDFromMetadata reads the legacy SP03 metadata key. |
| 350 | +// PRO04 keeps it for backward compatibility — existing org |
| 351 | +// subscriptions stamped before PRO04 deployed carry only this |
| 352 | +// key. Resolvers try the PRO04 keys first, fall back to this. |
| 245 | 353 | func stripeOrgIDFromMetadata(metadata map[string]string) int64 { |
| 246 | 354 | raw := strings.TrimSpace(metadata[stripebilling.MetadataOrgID]) |
| 247 | 355 | if raw == "" { |