tenseleyflow/shithub / 54c4e32

Browse files

web/orgs: subject-aware webhook with Principal resolution + cross-kind misroute guard

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
54c4e32f410b8933163d101eb4ddc3ba21dbec82
Parents
25b797f
Tree
ff06f75

2 changed files

StatusFile+-
M internal/web/handlers/orgs/billing_webhook.go 166 58
A internal/web/handlers/orgs/billing_webhook_resolve_test.go 125 0
internal/web/handlers/orgs/billing_webhook.gomodified
@@ -1,5 +1,11 @@
11
 // SPDX-License-Identifier: AGPL-3.0-or-later
22
 
3
+// PRO04 note: this file is now subject-agnostic — it routes Stripe
4
+// webhook events to either org or user billing state based on the
5
+// resolved Principal. The file still lives under `handlers/orgs/`
6
+// for wiring continuity; a follow-up sprint moves it to
7
+// `handlers/billing/` once the SP-only callers are gone.
8
+
39
 package orgs
410
 
511
 import (
@@ -13,7 +19,6 @@ import (
1319
 	"strings"
1420
 	"time"
1521
 
16
-	"github.com/jackc/pgx/v5"
1722
 	stripeapi "github.com/stripe/stripe-go/v85"
1823
 
1924
 	orgbilling "github.com/tenseleyFlow/shithub/internal/billing"
@@ -89,35 +94,70 @@ func (h *Handlers) applyStripeCheckoutCompleted(ctx context.Context, event strip
8994
 	if err := unmarshalStripeEventObject(event, &session); err != nil {
9095
 		return err
9196
 	}
92
-	orgID := stripeOrgIDFromMetadata(session.Metadata)
93
-	if orgID == 0 {
94
-		if id, err := strconv.ParseInt(strings.TrimSpace(session.ClientReferenceID), 10, 64); err == nil && id > 0 {
95
-			orgID = id
96
-		}
97
-	}
98
-	if orgID == 0 {
99
-		return errors.New("stripe checkout.session.completed missing shithub org metadata")
100
-	}
10197
 	customerID := stripeCustomerID(session.Customer)
10298
 	if customerID == "" {
10399
 		return errors.New("stripe checkout.session.completed missing customer")
104100
 	}
105
-	_, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, customerID)
101
+	principal, err := h.resolvePrincipalFromCheckout(ctx, &session, customerID)
102
+	if err != nil {
103
+		return err
104
+	}
105
+	_, err = orgbilling.SetStripeCustomerForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, customerID)
106106
 	return err
107107
 }
108108
 
109
+// resolvePrincipalFromCheckout walks the resolution chain for a
110
+// checkout.session.completed event. Order matches the spec:
111
+//  1. metadata.shithub_subject_kind + shithub_subject_id (PRO04 path)
112
+//  2. metadata.shithub_org_id (legacy SP03 path)
113
+//  3. client_reference_id parsed as int (legacy SP03 path)
114
+//  4. customer-id lookup against both tables
115
+//
116
+// Any path that yields a Principal returns immediately; the
117
+// fall-through error covers events that can't be matched at all.
118
+func (h *Handlers) resolvePrincipalFromCheckout(ctx context.Context, session *stripeapi.CheckoutSession, customerID string) (orgbilling.Principal, error) {
119
+	if p, ok := stripePrincipalFromMetadata(session.Metadata); ok {
120
+		return p, nil
121
+	}
122
+	if orgID := stripeOrgIDFromMetadata(session.Metadata); orgID != 0 {
123
+		return orgbilling.PrincipalForOrg(orgID), nil
124
+	}
125
+	if id, err := strconv.ParseInt(strings.TrimSpace(session.ClientReferenceID), 10, 64); err == nil && id > 0 {
126
+		// Legacy client_reference_id is org-only by convention.
127
+		return orgbilling.PrincipalForOrg(id), nil
128
+	}
129
+	if customerID != "" {
130
+		state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID)
131
+		if err == nil {
132
+			return state.Principal, nil
133
+		}
134
+		if !errors.Is(err, orgbilling.ErrPrincipalNotFound) {
135
+			return orgbilling.Principal{}, err
136
+		}
137
+	}
138
+	return orgbilling.Principal{}, errors.New("stripe checkout.session.completed missing shithub subject metadata")
139
+}
140
+
109141
 func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event stripeapi.Event) error {
110142
 	var sub stripeapi.Subscription
111143
 	if err := unmarshalStripeEventObject(event, &sub); err != nil {
112144
 		return err
113145
 	}
114
-	orgID, err := h.resolveOrgIDFromSubscription(ctx, &sub)
146
+	principal, err := h.resolvePrincipalFromSubscription(ctx, &sub)
115147
 	if err != nil {
116148
 		return err
117149
 	}
150
+	// Cross-kind price-id check: if the subscription's first item
151
+	// price doesn't match the expected price for the resolved kind,
152
+	// refuse to apply. A Pro price on an org subject (or Team on
153
+	// user) means metadata was misconfigured in the Stripe Dashboard;
154
+	// silently applying would corrupt the wrong table.
155
+	if err := h.guardPriceKindMatch(principal.Kind, &sub); err != nil {
156
+		return err
157
+	}
118158
 	customerID := stripeCustomerID(sub.Customer)
119159
 	if customerID != "" {
120
-		if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, customerID); err != nil {
160
+		if _, err := orgbilling.SetStripeCustomerForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, customerID); err != nil {
121161
 			return err
122162
 		}
123163
 	}
@@ -126,14 +166,13 @@ func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event strip
126166
 		return err
127167
 	}
128168
 	if status == orgbilling.SubscriptionStatusCanceled || string(event.Type) == "customer.subscription.deleted" {
129
-		_, err := orgbilling.MarkCanceled(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, event.ID)
169
+		_, err := orgbilling.MarkCanceledForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principal, event.ID)
130170
 		return err
131171
 	}
132172
 	itemID := stripeSubscriptionItemID(sub.Items)
133173
 	periodStart, periodEnd := stripeSubscriptionPeriod(sub.Items)
134
-	_, err = orgbilling.ApplySubscriptionSnapshot(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.SubscriptionSnapshot{
135
-		OrgID:                    orgID,
136
-		Plan:                     orgbilling.PlanTeam,
174
+	_, err = orgbilling.ApplySubscriptionSnapshotForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.PrincipalSubscriptionSnapshot{
175
+		Principal:                principal,
137176
 		Status:                   status,
138177
 		StripeSubscriptionID:     strings.TrimSpace(sub.ID),
139178
 		StripeSubscriptionItemID: itemID,
@@ -147,12 +186,79 @@ func (h *Handlers) applyStripeSubscriptionEvent(ctx context.Context, event strip
147186
 	return err
148187
 }
149188
 
189
+// resolvePrincipalFromSubscription walks the same chain as the
190
+// checkout resolver but starts from a subscription object.
191
+func (h *Handlers) resolvePrincipalFromSubscription(ctx context.Context, sub *stripeapi.Subscription) (orgbilling.Principal, error) {
192
+	if p, ok := stripePrincipalFromMetadata(sub.Metadata); ok {
193
+		return p, nil
194
+	}
195
+	if orgID := stripeOrgIDFromMetadata(sub.Metadata); orgID != 0 {
196
+		return orgbilling.PrincipalForOrg(orgID), nil
197
+	}
198
+	if customerID := stripeCustomerID(sub.Customer); customerID != "" {
199
+		state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID)
200
+		if err == nil {
201
+			return state.Principal, nil
202
+		}
203
+		if !errors.Is(err, orgbilling.ErrPrincipalNotFound) {
204
+			return orgbilling.Principal{}, err
205
+		}
206
+	}
207
+	if subID := strings.TrimSpace(sub.ID); subID != "" {
208
+		state, err := orgbilling.ResolvePrincipalByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID)
209
+		if err == nil {
210
+			return state.Principal, nil
211
+		}
212
+		if !errors.Is(err, orgbilling.ErrPrincipalNotFound) {
213
+			return orgbilling.Principal{}, err
214
+		}
215
+	}
216
+	return orgbilling.Principal{}, errors.New("stripe subscription does not map to a shithub subject")
217
+}
218
+
219
+// guardPriceKindMatch refuses to apply a subscription when the
220
+// price-id on its first line item doesn't match the expected price
221
+// for the resolved subject kind. Catches dashboard-side
222
+// misconfiguration before it writes the wrong table.
223
+//
224
+// The check requires the handler to know which price-id is Pro and
225
+// which is Team — that wiring lands via BillingPriceIDs(); a
226
+// non-configured client (Pro disabled) skips the check rather than
227
+// rejecting org events. PRO-disabled instances never see Pro
228
+// events, so the org path is unaffected.
229
+func (h *Handlers) guardPriceKindMatch(kind orgbilling.SubjectKind, sub *stripeapi.Subscription) error {
230
+	if sub == nil || sub.Items == nil || len(sub.Items.Data) == 0 || sub.Items.Data[0] == nil || sub.Items.Data[0].Price == nil {
231
+		// No price on the event — nothing to validate. Subsequent
232
+		// apply logic surfaces the missing-data error if needed.
233
+		return nil
234
+	}
235
+	priceID := strings.TrimSpace(sub.Items.Data[0].Price.ID)
236
+	teamPrice, proPrice := h.d.BillingPriceIDs()
237
+	switch kind {
238
+	case orgbilling.SubjectKindOrg:
239
+		if teamPrice != "" && priceID != "" && priceID != teamPrice {
240
+			if priceID == proPrice {
241
+				return fmt.Errorf("stripe subscription: Pro price %q applied to org subject — metadata likely misconfigured", priceID)
242
+			}
243
+			return fmt.Errorf("stripe subscription: price %q does not match expected team price %q for org subject", priceID, teamPrice)
244
+		}
245
+	case orgbilling.SubjectKindUser:
246
+		if proPrice != "" && priceID != "" && priceID != proPrice {
247
+			if priceID == teamPrice {
248
+				return fmt.Errorf("stripe subscription: Team price %q applied to user subject — metadata likely misconfigured", priceID)
249
+			}
250
+			return fmt.Errorf("stripe subscription: price %q does not match expected pro price %q for user subject", priceID, proPrice)
251
+		}
252
+	}
253
+	return nil
254
+}
255
+
150256
 func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi.Event) error {
151257
 	var inv stripeapi.Invoice
152258
 	if err := unmarshalStripeEventObject(event, &inv); err != nil {
153259
 		return err
154260
 	}
155
-	orgID, state, err := h.resolveOrgStateFromInvoice(ctx, &inv)
261
+	principalState, err := h.resolvePrincipalStateFromInvoice(ctx, &inv)
156262
 	if err != nil {
157263
 		return err
158264
 	}
@@ -160,8 +266,7 @@ func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi.
160266
 	if err != nil {
161267
 		return err
162268
 	}
163
-	if _, err := orgbilling.UpsertInvoice(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgbilling.InvoiceSnapshot{
164
-		OrgID:                orgID,
269
+	if _, err := orgbilling.UpsertInvoiceForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, orgbilling.InvoiceSnapshot{
165270
 		StripeInvoiceID:      strings.TrimSpace(inv.ID),
166271
 		StripeCustomerID:     stripeCustomerID(inv.Customer),
167272
 		StripeSubscriptionID: stripeInvoiceSubscriptionID(&inv),
@@ -184,64 +289,67 @@ func (h *Handlers) applyStripeInvoiceEvent(ctx context.Context, event stripeapi.
184289
 	switch string(event.Type) {
185290
 	case "invoice.payment_failed":
186291
 		graceUntil := time.Now().UTC().Add(h.d.BillingGracePeriod)
187
-		_, err := orgbilling.MarkPastDue(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, graceUntil, event.ID)
292
+		_, err := orgbilling.MarkPastDueForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, graceUntil, event.ID)
188293
 		return err
189294
 	case "invoice.payment_succeeded":
190
-		if state.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled {
191
-			_, err := orgbilling.MarkPaymentSucceeded(ctx, orgbilling.Deps{Pool: h.d.Pool}, orgID, event.ID)
295
+		if principalState.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled {
296
+			_, err := orgbilling.MarkPaymentSucceededForPrincipal(ctx, orgbilling.Deps{Pool: h.d.Pool}, principalState.Principal, event.ID)
192297
 			return err
193298
 		}
194299
 	}
195300
 	return nil
196301
 }
197302
 
198
-func (h *Handlers) resolveOrgIDFromSubscription(ctx context.Context, sub *stripeapi.Subscription) (int64, error) {
199
-	if orgID := stripeOrgIDFromMetadata(sub.Metadata); orgID != 0 {
200
-		return orgID, nil
201
-	}
202
-	if customerID := stripeCustomerID(sub.Customer); customerID != "" {
203
-		state, err := orgbilling.GetOrgBillingStateByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID)
204
-		if err == nil {
205
-			return state.OrgID, nil
206
-		}
207
-		if !errors.Is(err, pgx.ErrNoRows) {
208
-			return 0, err
209
-		}
210
-	}
211
-	if subID := strings.TrimSpace(sub.ID); subID != "" {
212
-		state, err := orgbilling.GetOrgBillingStateByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID)
213
-		if err == nil {
214
-			return state.OrgID, nil
215
-		}
216
-		if !errors.Is(err, pgx.ErrNoRows) {
217
-			return 0, err
218
-		}
219
-	}
220
-	return 0, errors.New("stripe subscription does not map to a shithub organization")
221
-}
222
-
223
-func (h *Handlers) resolveOrgStateFromInvoice(ctx context.Context, inv *stripeapi.Invoice) (int64, orgbilling.State, error) {
303
+// resolvePrincipalStateFromInvoice resolves Principal AND fetches
304
+// the current billing state in one shot — the apply branch needs
305
+// the SubscriptionStatus to decide whether to flip payment-
306
+// succeeded transitions. Mirrors the legacy
307
+// resolveOrgStateFromInvoice but returns a kind-tagged Principal.
308
+func (h *Handlers) resolvePrincipalStateFromInvoice(ctx context.Context, inv *stripeapi.Invoice) (orgbilling.PrincipalState, error) {
224309
 	if customerID := stripeCustomerID(inv.Customer); customerID != "" {
225
-		state, err := orgbilling.GetOrgBillingStateByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID)
310
+		state, err := orgbilling.ResolvePrincipalByStripeCustomer(ctx, orgbilling.Deps{Pool: h.d.Pool}, customerID)
226311
 		if err == nil {
227
-			return state.OrgID, state, nil
312
+			return state, nil
228313
 		}
229
-		if !errors.Is(err, pgx.ErrNoRows) {
230
-			return 0, orgbilling.State{}, err
314
+		if !errors.Is(err, orgbilling.ErrPrincipalNotFound) {
315
+			return orgbilling.PrincipalState{}, err
231316
 		}
232317
 	}
233318
 	if subID := stripeInvoiceSubscriptionID(inv); subID != "" {
234
-		state, err := orgbilling.GetOrgBillingStateByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID)
319
+		state, err := orgbilling.ResolvePrincipalByStripeSubscription(ctx, orgbilling.Deps{Pool: h.d.Pool}, subID)
235320
 		if err == nil {
236
-			return state.OrgID, state, nil
321
+			return state, nil
237322
 		}
238
-		if !errors.Is(err, pgx.ErrNoRows) {
239
-			return 0, orgbilling.State{}, err
323
+		if !errors.Is(err, orgbilling.ErrPrincipalNotFound) {
324
+			return orgbilling.PrincipalState{}, err
240325
 		}
241326
 	}
242
-	return 0, orgbilling.State{}, errors.New("stripe invoice does not map to a shithub organization")
327
+	return orgbilling.PrincipalState{}, errors.New("stripe invoice does not map to a shithub subject")
328
+}
329
+
330
+// stripePrincipalFromMetadata reads the PRO04 subject metadata
331
+// keys. Returns ok=false when either key is missing or malformed —
332
+// the caller falls through to the legacy resolution chain.
333
+func stripePrincipalFromMetadata(metadata map[string]string) (orgbilling.Principal, bool) {
334
+	if len(metadata) == 0 {
335
+		return orgbilling.Principal{}, false
336
+	}
337
+	kind := orgbilling.SubjectKind(strings.TrimSpace(metadata[stripebilling.MetadataSubjectKind]))
338
+	if !kind.Valid() {
339
+		return orgbilling.Principal{}, false
340
+	}
341
+	rawID := strings.TrimSpace(metadata[stripebilling.MetadataSubjectID])
342
+	id, err := strconv.ParseInt(rawID, 10, 64)
343
+	if err != nil || id <= 0 {
344
+		return orgbilling.Principal{}, false
345
+	}
346
+	return orgbilling.Principal{Kind: kind, ID: id}, true
243347
 }
244348
 
349
+// stripeOrgIDFromMetadata reads the legacy SP03 metadata key.
350
+// PRO04 keeps it for backward compatibility — existing org
351
+// subscriptions stamped before PRO04 deployed carry only this
352
+// key. Resolvers try the PRO04 keys first, fall back to this.
245353
 func stripeOrgIDFromMetadata(metadata map[string]string) int64 {
246354
 	raw := strings.TrimSpace(metadata[stripebilling.MetadataOrgID])
247355
 	if raw == "" {
internal/web/handlers/orgs/billing_webhook_resolve_test.goadded
@@ -0,0 +1,125 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package orgs
4
+
5
+import (
6
+	"testing"
7
+
8
+	orgbilling "github.com/tenseleyFlow/shithub/internal/billing"
9
+	"github.com/tenseleyFlow/shithub/internal/billing/stripebilling"
10
+)
11
+
12
+func TestStripePrincipalFromMetadataPRO04Keys(t *testing.T) {
13
+	t.Parallel()
14
+	m := map[string]string{
15
+		stripebilling.MetadataSubjectKind: "user",
16
+		stripebilling.MetadataSubjectID:   "42",
17
+	}
18
+	p, ok := stripePrincipalFromMetadata(m)
19
+	if !ok {
20
+		t.Fatal("expected ok=true for valid PRO04 metadata")
21
+	}
22
+	if !p.IsUser() || p.ID != 42 {
23
+		t.Errorf("user principal mismatch: %+v", p)
24
+	}
25
+}
26
+
27
+func TestStripePrincipalFromMetadataOrgKind(t *testing.T) {
28
+	t.Parallel()
29
+	m := map[string]string{
30
+		stripebilling.MetadataSubjectKind: "org",
31
+		stripebilling.MetadataSubjectID:   "7",
32
+	}
33
+	p, ok := stripePrincipalFromMetadata(m)
34
+	if !ok {
35
+		t.Fatal("expected ok=true for org PRO04 metadata")
36
+	}
37
+	if !p.IsOrg() || p.ID != 7 {
38
+		t.Errorf("org principal mismatch: %+v", p)
39
+	}
40
+}
41
+
42
+func TestStripePrincipalFromMetadataRejectsMissingKeys(t *testing.T) {
43
+	t.Parallel()
44
+	cases := []map[string]string{
45
+		nil,
46
+		{},
47
+		// Missing id.
48
+		{stripebilling.MetadataSubjectKind: "user"},
49
+		// Missing kind.
50
+		{stripebilling.MetadataSubjectID: "1"},
51
+		// Bogus kind.
52
+		{
53
+			stripebilling.MetadataSubjectKind: "alien",
54
+			stripebilling.MetadataSubjectID:   "1",
55
+		},
56
+		// Non-integer id.
57
+		{
58
+			stripebilling.MetadataSubjectKind: "user",
59
+			stripebilling.MetadataSubjectID:   "abc",
60
+		},
61
+		// Zero/negative id.
62
+		{
63
+			stripebilling.MetadataSubjectKind: "user",
64
+			stripebilling.MetadataSubjectID:   "0",
65
+		},
66
+		{
67
+			stripebilling.MetadataSubjectKind: "user",
68
+			stripebilling.MetadataSubjectID:   "-1",
69
+		},
70
+	}
71
+	for i, m := range cases {
72
+		if _, ok := stripePrincipalFromMetadata(m); ok {
73
+			t.Errorf("case %d: expected ok=false for %+v", i, m)
74
+		}
75
+	}
76
+}
77
+
78
+func TestStripeOrgIDFromMetadataLegacyKey(t *testing.T) {
79
+	t.Parallel()
80
+	m := map[string]string{stripebilling.MetadataOrgID: "99"}
81
+	if got := stripeOrgIDFromMetadata(m); got != 99 {
82
+		t.Errorf("legacy org id: got %d, want 99", got)
83
+	}
84
+	// Missing / bad ones return 0.
85
+	for _, m := range []map[string]string{
86
+		nil,
87
+		{},
88
+		{stripebilling.MetadataOrgID: ""},
89
+		{stripebilling.MetadataOrgID: "abc"},
90
+		{stripebilling.MetadataOrgID: "0"},
91
+		{stripebilling.MetadataOrgID: "-1"},
92
+	} {
93
+		if got := stripeOrgIDFromMetadata(m); got != 0 {
94
+			t.Errorf("expected 0 for %+v, got %d", m, got)
95
+		}
96
+	}
97
+}
98
+
99
+// Make sure the resolver chain prefers PRO04 keys over legacy when
100
+// both are present — defends against an old org's metadata being
101
+// re-stamped during Pro adoption.
102
+func TestStripePrincipalPRO04KeysPreferredOverLegacy(t *testing.T) {
103
+	t.Parallel()
104
+	m := map[string]string{
105
+		stripebilling.MetadataSubjectKind: "user",
106
+		stripebilling.MetadataSubjectID:   "42",
107
+		stripebilling.MetadataOrgID:       "7", // would route to org if used
108
+	}
109
+	p, ok := stripePrincipalFromMetadata(m)
110
+	if !ok || !p.IsUser() || p.ID != 42 {
111
+		t.Errorf("PRO04 user keys not preferred: got %+v ok=%v", p, ok)
112
+	}
113
+}
114
+
115
+// Mismatch guard: when the orchestrator constructs a Principal,
116
+// the kind and id must survive String() formatting (used in logs).
117
+func TestPrincipalStringRoundTrip(t *testing.T) {
118
+	t.Parallel()
119
+	if got := orgbilling.PrincipalForUser(99).String(); got != "user:99" {
120
+		t.Errorf("user string: %q", got)
121
+	}
122
+	if got := orgbilling.PrincipalForOrg(99).String(); got != "org:99" {
123
+		t.Errorf("org string: %q", got)
124
+	}
125
+}