Go · 35611 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package orgs_test
4
5 import (
6 "context"
7 "encoding/json"
8 "io"
9 "log/slog"
10 "net/http"
11 "net/http/httptest"
12 "net/url"
13 "strconv"
14 "strings"
15 "testing"
16 "testing/fstest"
17 "time"
18
19 "github.com/go-chi/chi/v5"
20 "github.com/jackc/pgx/v5/pgxpool"
21 stripeapi "github.com/stripe/stripe-go/v85"
22
23 orgbilling "github.com/tenseleyFlow/shithub/internal/billing"
24 billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc"
25 "github.com/tenseleyFlow/shithub/internal/billing/stripebilling"
26 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
27 orgsh "github.com/tenseleyFlow/shithub/internal/web/handlers/orgs"
28 "github.com/tenseleyFlow/shithub/internal/web/middleware"
29 "github.com/tenseleyFlow/shithub/internal/web/render"
30 )
31
32 func TestOrgBillingCheckoutRedirectsToStripeAndCreatesCustomer(t *testing.T) {
33 t.Parallel()
34 ctx := context.Background()
35 pool := dbtest.NewTestDB(t)
36 ownerID := insertOrgAvatarUser(t, pool, "owner")
37 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
38 fake := &fakeStripeRemote{
39 createCustomerFn: func(_ context.Context, in stripebilling.CustomerInput) (stripebilling.Customer, error) {
40 if in.OrgID != orgID || in.OrgSlug != "acme" {
41 t.Fatalf("unexpected customer input: %+v", in)
42 }
43 return stripebilling.Customer{ID: "cus_test_checkout"}, nil
44 },
45 createCheckoutFn: func(_ context.Context, in stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) {
46 if in.CustomerID != "cus_test_checkout" {
47 t.Fatalf("checkout customer = %q", in.CustomerID)
48 }
49 if in.SeatCount != 1 {
50 t.Fatalf("checkout seats = %d, want 1", in.SeatCount)
51 }
52 if !strings.Contains(in.SuccessURL, "/organizations/acme/billing/success") {
53 t.Fatalf("success url = %q", in.SuccessURL)
54 }
55 if !strings.Contains(in.CancelURL, "/organizations/acme/billing/cancel") {
56 t.Fatalf("cancel url = %q", in.CancelURL)
57 }
58 return stripebilling.CheckoutSession{ID: "cs_test", URL: "https://checkout.stripe.test/session"}, nil
59 },
60 }
61 mux := newOrgBillingMux(t, pool, ownerID, fake)
62
63 resp := httptest.NewRecorder()
64 req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/checkout", url.Values{})
65 mux.ServeHTTP(resp, req)
66 if resp.Code != http.StatusSeeOther {
67 t.Fatalf("checkout status=%d body=%s", resp.Code, resp.Body.String())
68 }
69 if got := resp.Header().Get("Location"); got != "https://checkout.stripe.test/session" {
70 t.Fatalf("checkout redirect=%q", got)
71 }
72 state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID)
73 if err != nil {
74 t.Fatalf("GetOrgBillingState: %v", err)
75 }
76 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_checkout" {
77 t.Fatalf("expected stripe customer saved, got %+v", state.StripeCustomerID)
78 }
79 }
80
81 func TestOrgBillingPortalRedirectsToStripe(t *testing.T) {
82 t.Parallel()
83 ctx := context.Background()
84 pool := dbtest.NewTestDB(t)
85 ownerID := insertOrgAvatarUser(t, pool, "owner")
86 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
87 if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: pool}, orgID, "cus_test_portal"); err != nil {
88 t.Fatalf("SetStripeCustomer: %v", err)
89 }
90 fake := &fakeStripeRemote{
91 createPortalFn: func(_ context.Context, in stripebilling.PortalInput) (stripebilling.PortalSession, error) {
92 if in.CustomerID != "cus_test_portal" {
93 t.Fatalf("portal customer = %q", in.CustomerID)
94 }
95 if !strings.Contains(in.ReturnURL, "/organizations/acme/settings/billing") {
96 t.Fatalf("portal return url = %q", in.ReturnURL)
97 }
98 return stripebilling.PortalSession{ID: "bps_test", URL: "https://billing.stripe.test/session"}, nil
99 },
100 }
101 mux := newOrgBillingMux(t, pool, ownerID, fake)
102
103 resp := httptest.NewRecorder()
104 req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/portal", url.Values{})
105 mux.ServeHTTP(resp, req)
106 if resp.Code != http.StatusSeeOther {
107 t.Fatalf("portal status=%d body=%s", resp.Code, resp.Body.String())
108 }
109 if got := resp.Header().Get("Location"); got != "https://billing.stripe.test/session" {
110 t.Fatalf("portal redirect=%q", got)
111 }
112 }
113
114 func TestOrgBillingResultPagesRenderPostCheckoutState(t *testing.T) {
115 t.Parallel()
116 pool := dbtest.NewTestDB(t)
117 ownerID := insertOrgAvatarUser(t, pool, "owner")
118 insertOrgAvatarOrg(t, pool, ownerID, "acme")
119 mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{})
120
121 for _, tc := range []struct {
122 path string
123 want string
124 }{
125 {path: "/organizations/acme/billing/success", want: "RESULT=success;HEADING=Checkout complete"},
126 {path: "/organizations/acme/billing/cancel", want: "RESULT=canceled;HEADING=Checkout canceled"},
127 } {
128 resp := httptest.NewRecorder()
129 req := newOrgFormRequest(http.MethodGet, tc.path, nil)
130 mux.ServeHTTP(resp, req)
131 if resp.Code != http.StatusOK {
132 t.Fatalf("%s status=%d body=%s", tc.path, resp.Code, resp.Body.String())
133 }
134 if !strings.Contains(resp.Body.String(), tc.want) {
135 t.Fatalf("%s missing %q in body %s", tc.path, tc.want, resp.Body.String())
136 }
137 }
138 }
139
140 func TestOrgBillingSettingsRequiresOwner(t *testing.T) {
141 t.Parallel()
142 pool := dbtest.NewTestDB(t)
143 ownerID := insertOrgAvatarUser(t, pool, "owner")
144 insertOrgAvatarOrg(t, pool, ownerID, "acme")
145 memberID := insertOrgAvatarUser(t, pool, "member")
146 mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: memberID, Username: "member"}, &fakeStripeRemote{})
147
148 resp := httptest.NewRecorder()
149 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
150 mux.ServeHTTP(resp, req)
151 if resp.Code != http.StatusForbidden {
152 t.Fatalf("settings status=%d body=%s", resp.Code, resp.Body.String())
153 }
154 }
155
156 func TestOrgBillingSettingsRendersSeatBreakdownAndHidesStripeIDs(t *testing.T) {
157 t.Parallel()
158 ctx := context.Background()
159 pool := dbtest.NewTestDB(t)
160 ownerID := insertOrgAvatarUser(t, pool, "owner")
161 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
162 if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: pool}, orgID, "cus_owner_secret"); err != nil {
163 t.Fatalf("SetStripeCustomer: %v", err)
164 }
165 if _, err := orgbilling.SyncSeatSnapshot(ctx, orgbilling.Deps{Pool: pool}, orgbilling.SeatSnapshot{
166 OrgID: orgID,
167 ActiveMembers: 3,
168 BillableSeats: 3,
169 Source: "test",
170 }); err != nil {
171 t.Fatalf("SyncSeatSnapshot: %v", err)
172 }
173 insertBillingPendingInvitation(t, pool, orgID, "pending@example.com", []byte{1, 2, 3})
174 mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{})
175
176 resp := httptest.NewRecorder()
177 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
178 mux.ServeHTTP(resp, req)
179 body := resp.Body.String()
180 if resp.Code != http.StatusOK {
181 t.Fatalf("settings status=%d body=%s", resp.Code, body)
182 }
183 if !strings.Contains(body, "SEATS=1/3/1;") {
184 t.Fatalf("settings did not render seat breakdown: %s", body)
185 }
186 if strings.Contains(body, "cus_owner_secret") {
187 t.Fatalf("owner billing page leaked Stripe customer id: %s", body)
188 }
189 }
190
191 func TestOrgBillingSettingsRendersUsageThresholds(t *testing.T) {
192 t.Parallel()
193 ctx := context.Background()
194 pool := dbtest.NewTestDB(t)
195 ownerID := insertOrgAvatarUser(t, pool, "owner")
196 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
197 start, end := orgbilling.MonthlyUsagePeriod(time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC))
198 if _, err := orgbilling.UpsertOrgUsageCounters(ctx, orgbilling.Deps{Pool: pool}, orgbilling.UsageCounterSnapshot{
199 OrgID: orgID,
200 RepoStorageBytes: 490 * 1024 * 1024,
201 ObjectStorageBytes: 0,
202 ActionsLogBytes: 0,
203 ActionsArtifactBytes: 0,
204 ActionsMinutesUsed: 1600,
205 ActionsPeriodStart: start,
206 ActionsPeriodEnd: end,
207 CalculatedAt: start.Add(12 * time.Hour),
208 }); err != nil {
209 t.Fatalf("UpsertOrgUsageCounters: %v", err)
210 }
211 mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{})
212
213 resp := httptest.NewRecorder()
214 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
215 mux.ServeHTTP(resp, req)
216 body := resp.Body.String()
217 if resp.Code != http.StatusOK {
218 t.Fatalf("settings status=%d body=%s", resp.Code, body)
219 }
220 if !strings.Contains(body, "USAGE=storage:490 MiB/500 MiB/98%/is-danger;") {
221 t.Fatalf("settings did not render storage usage threshold: %s", body)
222 }
223 if !strings.Contains(body, "USAGE=actions-minutes:1600 minutes/2000 minutes/80%/is-warning;") {
224 t.Fatalf("settings did not render actions usage threshold: %s", body)
225 }
226 if !strings.Contains(body, "USAGE_ALERT=This organization has used at least 95% of its storage quota.") {
227 t.Fatalf("settings did not render quota warning: %s", body)
228 }
229 }
230
231 func TestOrgBillingSettingsAppliesQuotaOverrides(t *testing.T) {
232 t.Parallel()
233 ctx := context.Background()
234 pool := dbtest.NewTestDB(t)
235 ownerID := insertOrgAvatarUser(t, pool, "owner")
236 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
237 start, end := orgbilling.MonthlyUsagePeriod(time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC))
238 deps := orgbilling.Deps{Pool: pool}
239 if _, err := orgbilling.UpsertOrgUsageCounters(ctx, deps, orgbilling.UsageCounterSnapshot{
240 OrgID: orgID,
241 RepoStorageBytes: 6 * 1024 * 1024 * 1024,
242 ObjectStorageBytes: 0,
243 ActionsLogBytes: 0,
244 ActionsArtifactBytes: 0,
245 ActionsMinutesUsed: 2100,
246 ActionsPeriodStart: start,
247 ActionsPeriodEnd: end,
248 CalculatedAt: start.Add(12 * time.Hour),
249 }); err != nil {
250 t.Fatalf("UpsertOrgUsageCounters: %v", err)
251 }
252 if _, err := orgbilling.UpsertOrgQuotaOverride(ctx, deps, orgbilling.QuotaOverrideInput{
253 OrgID: orgID,
254 Kind: orgbilling.QuotaKindStorageBytes,
255 LimitValue: 10 * 1024 * 1024 * 1024,
256 Note: "temporary migration",
257 CreatedByUserID: ownerID,
258 }); err != nil {
259 t.Fatalf("UpsertOrgQuotaOverride storage: %v", err)
260 }
261 if _, err := orgbilling.UpsertOrgQuotaOverride(ctx, deps, orgbilling.QuotaOverrideInput{
262 OrgID: orgID,
263 Kind: orgbilling.QuotaKindActionsMinutes,
264 Unlimited: true,
265 CreatedByUserID: ownerID,
266 }); err != nil {
267 t.Fatalf("UpsertOrgQuotaOverride actions: %v", err)
268 }
269 mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner", IsSiteAdmin: true}, &fakeStripeRemote{})
270
271 resp := httptest.NewRecorder()
272 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
273 mux.ServeHTTP(resp, req)
274 body := resp.Body.String()
275 if resp.Code != http.StatusOK {
276 t.Fatalf("settings status=%d body=%s", resp.Code, body)
277 }
278 if !strings.Contains(body, "USAGE=storage:6 GiB/10 GiB override/60%/is-ok;") {
279 t.Fatalf("settings did not apply storage override: %s", body)
280 }
281 if !strings.Contains(body, "USAGE=actions-minutes:2100 minutes/Unlimited override/Unlimited/is-ok;") {
282 t.Fatalf("settings did not apply actions override: %s", body)
283 }
284 if !strings.Contains(body, "OVERRIDE=Storage:10 GiB;") || !strings.Contains(body, "OVERRIDE=Actions minutes:Unlimited;") {
285 t.Fatalf("settings did not render site-admin quota overrides: %s", body)
286 }
287 }
288
289 func TestOrgBillingSiteAdminCanManageQuotaOverrides(t *testing.T) {
290 t.Parallel()
291 ctx := context.Background()
292 pool := dbtest.NewTestDB(t)
293 ownerID := insertOrgAvatarUser(t, pool, "owner")
294 adminID := insertOrgAvatarUser(t, pool, "admin")
295 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
296 mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: adminID, Username: "admin", IsSiteAdmin: true}, &fakeStripeRemote{})
297
298 resp := httptest.NewRecorder()
299 req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{
300 "kind": {"storage_bytes"},
301 "limit_value": {"1073741824"},
302 "note": {"support migration"},
303 })
304 mux.ServeHTTP(resp, req)
305 if resp.Code != http.StatusSeeOther {
306 t.Fatalf("save status=%d body=%s", resp.Code, resp.Body.String())
307 }
308 if got := resp.Header().Get("Location"); got != "/organizations/acme/settings/billing?notice=quota-override-saved" {
309 t.Fatalf("save redirect=%q", got)
310 }
311 override, err := orgbilling.GetOrgQuotaOverride(ctx, orgbilling.Deps{Pool: pool}, orgID, orgbilling.QuotaKindStorageBytes)
312 if err != nil {
313 t.Fatalf("GetOrgQuotaOverride: %v", err)
314 }
315 if override.Unlimited || !override.LimitValue.Valid || override.LimitValue.Int64 != 1073741824 || override.Note != "support migration" {
316 t.Fatalf("unexpected override: %+v", override)
317 }
318 if !override.CreatedByUserID.Valid || override.CreatedByUserID.Int64 != adminID {
319 t.Fatalf("created_by_user_id=%+v, want %d", override.CreatedByUserID, adminID)
320 }
321
322 resp = httptest.NewRecorder()
323 req = newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{
324 "kind": {"actions_minutes"},
325 "unlimited": {"1"},
326 })
327 mux.ServeHTTP(resp, req)
328 if resp.Code != http.StatusSeeOther {
329 t.Fatalf("unlimited save status=%d body=%s", resp.Code, resp.Body.String())
330 }
331 actionsOverride, err := orgbilling.GetOrgQuotaOverride(ctx, orgbilling.Deps{Pool: pool}, orgID, orgbilling.QuotaKindActionsMinutes)
332 if err != nil {
333 t.Fatalf("GetOrgQuotaOverride actions: %v", err)
334 }
335 if !actionsOverride.Unlimited || actionsOverride.LimitValue.Valid {
336 t.Fatalf("unexpected unlimited override: %+v", actionsOverride)
337 }
338
339 resp = httptest.NewRecorder()
340 req = newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides/delete", url.Values{
341 "kind": {"storage_bytes"},
342 })
343 mux.ServeHTTP(resp, req)
344 if resp.Code != http.StatusSeeOther {
345 t.Fatalf("delete status=%d body=%s", resp.Code, resp.Body.String())
346 }
347 if got := resp.Header().Get("Location"); got != "/organizations/acme/settings/billing?notice=quota-override-cleared" {
348 t.Fatalf("delete redirect=%q", got)
349 }
350 overrides, err := orgbilling.ListOrgQuotaOverrides(ctx, orgbilling.Deps{Pool: pool}, orgID)
351 if err != nil {
352 t.Fatalf("ListOrgQuotaOverrides: %v", err)
353 }
354 if len(overrides) != 1 || overrides[0].Kind != orgbilling.QuotaKindActionsMinutes {
355 t.Fatalf("unexpected remaining overrides: %+v", overrides)
356 }
357 }
358
359 func TestOrgBillingQuotaOverridesRequireSiteAdmin(t *testing.T) {
360 t.Parallel()
361 ctx := context.Background()
362 pool := dbtest.NewTestDB(t)
363 ownerID := insertOrgAvatarUser(t, pool, "owner")
364 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
365 mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{})
366
367 resp := httptest.NewRecorder()
368 req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{
369 "kind": {"storage_bytes"},
370 "limit_value": {"1"},
371 })
372 mux.ServeHTTP(resp, req)
373 if resp.Code != http.StatusNotFound {
374 t.Fatalf("save status=%d body=%s", resp.Code, resp.Body.String())
375 }
376 overrides, err := orgbilling.ListOrgQuotaOverrides(ctx, orgbilling.Deps{Pool: pool}, orgID)
377 if err != nil {
378 t.Fatalf("ListOrgQuotaOverrides: %v", err)
379 }
380 if len(overrides) != 0 {
381 t.Fatalf("non-admin created overrides: %+v", overrides)
382 }
383 }
384
385 func TestOrgBillingSettingsSiteAdminDebugShowsProviderState(t *testing.T) {
386 t.Parallel()
387 ctx := context.Background()
388 pool := dbtest.NewTestDB(t)
389 ownerID := insertOrgAvatarUser(t, pool, "owner")
390 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
391 deps := orgbilling.Deps{Pool: pool}
392 if _, err := orgbilling.SetStripeCustomer(ctx, deps, orgID, "cus_debug"); err != nil {
393 t.Fatalf("SetStripeCustomer: %v", err)
394 }
395 if _, err := orgbilling.ApplySubscriptionSnapshot(ctx, deps, orgbilling.SubscriptionSnapshot{
396 OrgID: orgID,
397 Plan: orgbilling.PlanTeam,
398 Status: orgbilling.SubscriptionStatusActive,
399 StripeSubscriptionID: "sub_debug",
400 StripeSubscriptionItemID: "si_debug",
401 CurrentPeriodStart: time.Now().UTC().Add(-time.Hour),
402 CurrentPeriodEnd: time.Now().UTC().Add(30 * 24 * time.Hour),
403 LastWebhookEventID: "evt_debug",
404 }); err != nil {
405 t.Fatalf("ApplySubscriptionSnapshot: %v", err)
406 }
407 if _, _, err := orgbilling.RecordWebhookEvent(ctx, deps, orgbilling.WebhookEvent{
408 ProviderEventID: "evt_debug",
409 EventType: "customer.subscription.updated",
410 APIVersion: "2024-06-20",
411 Payload: []byte(`{"id":"evt_debug"}`),
412 }); err != nil {
413 t.Fatalf("RecordWebhookEvent: %v", err)
414 }
415 if _, err := orgbilling.MarkWebhookEventProcessed(ctx, deps, "evt_debug"); err != nil {
416 t.Fatalf("MarkWebhookEventProcessed: %v", err)
417 }
418 mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner", IsSiteAdmin: true}, &fakeStripeRemote{})
419
420 resp := httptest.NewRecorder()
421 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
422 mux.ServeHTTP(resp, req)
423 body := resp.Body.String()
424 if resp.Code != http.StatusOK {
425 t.Fatalf("settings status=%d body=%s", resp.Code, body)
426 }
427 if !strings.Contains(body, "DEBUG=cus_debug|sub_debug|si_debug|evt_debug|processed;") {
428 t.Fatalf("settings did not render site-admin debug state: %s", body)
429 }
430 }
431
432 func TestOrgBillingSettingsShowsPastDueAlert(t *testing.T) {
433 t.Parallel()
434 ctx := context.Background()
435 pool := dbtest.NewTestDB(t)
436 ownerID := insertOrgAvatarUser(t, pool, "owner")
437 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
438 if _, err := orgbilling.MarkPastDue(ctx, orgbilling.Deps{Pool: pool}, orgID, time.Now().UTC().Add(24*time.Hour), "evt_failed"); err != nil {
439 t.Fatalf("MarkPastDue: %v", err)
440 }
441 mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{})
442
443 resp := httptest.NewRecorder()
444 req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil)
445 mux.ServeHTTP(resp, req)
446 body := resp.Body.String()
447 if resp.Code != http.StatusOK {
448 t.Fatalf("settings status=%d body=%s", resp.Code, body)
449 }
450 if !strings.Contains(body, "ALERT=Payment failed.") {
451 t.Fatalf("settings did not render past-due alert: %s", body)
452 }
453 }
454
455 func TestOrgBillingWebhookProcessesSubscriptionAndStaysIdempotent(t *testing.T) {
456 t.Parallel()
457 ctx := context.Background()
458 pool := dbtest.NewTestDB(t)
459 ownerID := insertOrgAvatarUser(t, pool, "owner")
460 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
461 raw, err := json.Marshal(map[string]any{
462 "id": "sub_test",
463 "customer": "cus_test_webhook",
464 "status": "active",
465 "cancel_at_period_end": false,
466 "trial_end": int64(0),
467 "canceled_at": int64(0),
468 "metadata": map[string]string{stripebilling.MetadataOrgID: strconv.FormatInt(orgID, 10)},
469 "items": map[string]any{"data": []map[string]any{{
470 "id": "si_test_webhook",
471 "current_period_start": time.Now().UTC().Add(-time.Hour).Unix(),
472 "current_period_end": time.Now().UTC().Add(30 * 24 * time.Hour).Unix(),
473 }}},
474 })
475 if err != nil {
476 t.Fatalf("marshal subscription raw: %v", err)
477 }
478 fake := &fakeStripeRemote{
479 verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) {
480 return stripeapi.Event{
481 ID: "evt_sub_active",
482 Type: stripeapi.EventType("customer.subscription.updated"),
483 APIVersion: "2024-06-20",
484 Data: &stripeapi.EventData{Raw: raw},
485 }, nil
486 },
487 }
488 mux := newOrgBillingMux(t, pool, ownerID, fake)
489
490 req := httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(`{"id":"evt_sub_active"}`))
491 req.Header.Set("Stripe-Signature", "sig_test")
492 resp := httptest.NewRecorder()
493 mux.ServeHTTP(resp, req)
494 if resp.Code != http.StatusOK {
495 t.Fatalf("first webhook status=%d body=%s", resp.Code, resp.Body.String())
496 }
497 state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID)
498 if err != nil {
499 t.Fatalf("GetOrgBillingState: %v", err)
500 }
501 if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusActive {
502 t.Fatalf("unexpected billing state: %+v", state)
503 }
504 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_webhook" {
505 t.Fatalf("expected customer id saved, got %+v", state.StripeCustomerID)
506 }
507 if !state.StripeSubscriptionID.Valid || state.StripeSubscriptionID.String != "sub_test" {
508 t.Fatalf("expected subscription id saved, got %+v", state.StripeSubscriptionID)
509 }
510 receipt, err := billingdb.New().GetWebhookEventReceipt(ctx, pool, "evt_sub_active")
511 if err != nil {
512 t.Fatalf("GetWebhookEventReceipt: %v", err)
513 }
514 if !receipt.ProcessedAt.Valid || receipt.ProcessingAttempts != 1 {
515 t.Fatalf("unexpected receipt after first processing: %+v", receipt)
516 }
517 // PRO08 A2: subject must be recorded on the receipt after resolve.
518 if !receipt.SubjectKind.Valid || receipt.SubjectKind.BillingSubjectKind != billingdb.BillingSubjectKindOrg {
519 t.Fatalf("receipt subject_kind: got %+v, want org", receipt.SubjectKind)
520 }
521 if !receipt.SubjectID.Valid || receipt.SubjectID.Int64 != orgID {
522 t.Fatalf("receipt subject_id: got %+v, want %d", receipt.SubjectID, orgID)
523 }
524
525 req = httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(`{"id":"evt_sub_active"}`))
526 req.Header.Set("Stripe-Signature", "sig_test")
527 resp = httptest.NewRecorder()
528 mux.ServeHTTP(resp, req)
529 if resp.Code != http.StatusOK {
530 t.Fatalf("duplicate webhook status=%d body=%s", resp.Code, resp.Body.String())
531 }
532 receipt, err = billingdb.New().GetWebhookEventReceipt(ctx, pool, "evt_sub_active")
533 if err != nil {
534 t.Fatalf("GetWebhookEventReceipt duplicate: %v", err)
535 }
536 if receipt.ProcessingAttempts != 1 {
537 t.Fatalf("duplicate webhook should not reprocess receipt: %+v", receipt)
538 }
539 }
540
541 func TestOrgBillingWebhookCheckoutCompletedStoresCustomerOnly(t *testing.T) {
542 t.Parallel()
543 ctx := context.Background()
544 pool := dbtest.NewTestDB(t)
545 ownerID := insertOrgAvatarUser(t, pool, "owner")
546 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
547 raw := mustJSONRaw(t, map[string]any{
548 "id": "cs_test_completed",
549 "object": "checkout.session",
550 "customer": "cus_test_checkout_completed",
551 "client_reference_id": strconv.FormatInt(orgID, 10),
552 })
553 fake := &fakeStripeRemote{
554 verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) {
555 return stripeapi.Event{
556 ID: "evt_checkout_completed",
557 Type: stripeapi.EventType("checkout.session.completed"),
558 APIVersion: "2024-06-20",
559 Data: &stripeapi.EventData{Raw: raw},
560 }, nil
561 },
562 }
563 mux := newOrgBillingMux(t, pool, ownerID, fake)
564
565 resp := postBillingWebhook(t, mux, "evt_checkout_completed")
566 if resp.Code != http.StatusOK {
567 t.Fatalf("checkout webhook status=%d body=%s", resp.Code, resp.Body.String())
568 }
569 state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID)
570 if err != nil {
571 t.Fatalf("GetOrgBillingState: %v", err)
572 }
573 if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_checkout_completed" {
574 t.Fatalf("expected checkout customer saved, got %+v", state.StripeCustomerID)
575 }
576 if state.Plan != orgbilling.PlanFree || state.SubscriptionStatus != orgbilling.SubscriptionStatusNone {
577 t.Fatalf("checkout completion must not activate paid state by itself: %+v", state)
578 }
579 }
580
581 func TestOrgBillingWebhookHandlesInvoiceFailureRecoveryAndCancellation(t *testing.T) {
582 t.Parallel()
583 ctx := context.Background()
584 pool := dbtest.NewTestDB(t)
585 ownerID := insertOrgAvatarUser(t, pool, "owner")
586 orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme")
587 deps := orgbilling.Deps{Pool: pool}
588 if _, err := orgbilling.SetStripeCustomer(ctx, deps, orgID, "cus_test_lifecycle"); err != nil {
589 t.Fatalf("SetStripeCustomer: %v", err)
590 }
591 start := time.Now().UTC().Truncate(time.Second)
592 if _, err := orgbilling.ApplySubscriptionSnapshot(ctx, deps, orgbilling.SubscriptionSnapshot{
593 OrgID: orgID,
594 Plan: orgbilling.PlanTeam,
595 Status: orgbilling.SubscriptionStatusActive,
596 StripeSubscriptionID: "sub_test_lifecycle",
597 StripeSubscriptionItemID: "si_test_lifecycle",
598 CurrentPeriodStart: start,
599 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
600 LastWebhookEventID: "evt_seed_active",
601 }); err != nil {
602 t.Fatalf("ApplySubscriptionSnapshot: %v", err)
603 }
604
605 var current stripeapi.Event
606 fake := &fakeStripeRemote{
607 verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) {
608 return current, nil
609 },
610 }
611 mux := newOrgBillingMux(t, pool, ownerID, fake)
612
613 current = stripeTestEvent(t, "evt_invoice_failed", "invoice.payment_failed", map[string]any{
614 "id": "in_test_lifecycle",
615 "object": "invoice",
616 "customer": "cus_test_lifecycle",
617 "status": "open",
618 "number": "INV-FAILED",
619 "currency": "usd",
620 "amount_due": int64(400),
621 "amount_paid": int64(0),
622 "amount_remaining": int64(400),
623 "hosted_invoice_url": "https://pay.stripe.test/invoice",
624 "invoice_pdf": "https://pay.stripe.test/invoice.pdf",
625 "period_start": start.Unix(),
626 "period_end": start.Add(30 * 24 * time.Hour).Unix(),
627 "due_date": start.Add(3 * 24 * time.Hour).Unix(),
628 "status_transitions": map[string]any{},
629 "subscription_details": map[string]any{},
630 })
631 resp := postBillingWebhook(t, mux, "evt_invoice_failed")
632 if resp.Code != http.StatusOK {
633 t.Fatalf("payment_failed webhook status=%d body=%s", resp.Code, resp.Body.String())
634 }
635 state, err := orgbilling.GetOrgBillingState(ctx, deps, orgID)
636 if err != nil {
637 t.Fatalf("GetOrgBillingState after failed payment: %v", err)
638 }
639 if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusPastDue {
640 t.Fatalf("payment_failed should keep Team plan and mark past_due: %+v", state)
641 }
642 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue {
643 t.Fatalf("payment_failed should set past_due lock fields: %+v", state)
644 }
645 if !state.GraceUntil.Valid {
646 t.Fatalf("payment_failed should set grace_until: %+v", state)
647 }
648
649 current = stripeTestEvent(t, "evt_invoice_paid", "invoice.payment_succeeded", map[string]any{
650 "id": "in_test_lifecycle",
651 "object": "invoice",
652 "customer": "cus_test_lifecycle",
653 "status": "paid",
654 "number": "INV-FAILED",
655 "currency": "usd",
656 "amount_due": int64(400),
657 "amount_paid": int64(400),
658 "amount_remaining": int64(0),
659 "period_start": start.Unix(),
660 "period_end": start.Add(30 * 24 * time.Hour).Unix(),
661 "status_transitions": map[string]any{"paid_at": start.Add(time.Hour).Unix()},
662 "subscription_details": map[string]any{},
663 })
664 resp = postBillingWebhook(t, mux, "evt_invoice_paid")
665 if resp.Code != http.StatusOK {
666 t.Fatalf("payment_succeeded webhook status=%d body=%s", resp.Code, resp.Body.String())
667 }
668 state, err = orgbilling.GetOrgBillingState(ctx, deps, orgID)
669 if err != nil {
670 t.Fatalf("GetOrgBillingState after paid invoice: %v", err)
671 }
672 if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusActive {
673 t.Fatalf("payment_succeeded should recover active Team state: %+v", state)
674 }
675 if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid {
676 t.Fatalf("payment_succeeded should clear billing action lock: %+v", state)
677 }
678 if state.PastDueAt.Valid {
679 t.Fatalf("payment_succeeded should clear past_due_at: %+v", state)
680 }
681
682 current = stripeTestEvent(t, "evt_subscription_deleted", "customer.subscription.deleted", map[string]any{
683 "id": "sub_test_lifecycle",
684 "object": "subscription",
685 "customer": "cus_test_lifecycle",
686 "status": "canceled",
687 "cancel_at_period_end": false,
688 "trial_end": int64(0),
689 "canceled_at": start.Add(2 * time.Hour).Unix(),
690 "metadata": map[string]string{stripebilling.MetadataOrgID: strconv.FormatInt(orgID, 10)},
691 "items": map[string]any{"data": []map[string]any{{
692 "id": "si_test_lifecycle",
693 "current_period_start": start.Unix(),
694 "current_period_end": start.Add(30 * 24 * time.Hour).Unix(),
695 }}},
696 })
697 resp = postBillingWebhook(t, mux, "evt_subscription_deleted")
698 if resp.Code != http.StatusOK {
699 t.Fatalf("subscription deleted webhook status=%d body=%s", resp.Code, resp.Body.String())
700 }
701 state, err = orgbilling.GetOrgBillingState(ctx, deps, orgID)
702 if err != nil {
703 t.Fatalf("GetOrgBillingState after cancellation: %v", err)
704 }
705 if state.Plan != orgbilling.PlanFree || state.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled {
706 t.Fatalf("subscription deletion should downgrade to Free canceled state: %+v", state)
707 }
708 if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonCanceled {
709 t.Fatalf("subscription deletion should set canceled lock fields: %+v", state)
710 }
711 }
712
713 type fakeStripeRemote struct {
714 createCustomerFn func(context.Context, stripebilling.CustomerInput) (stripebilling.Customer, error)
715 createCheckoutFn func(context.Context, stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error)
716 createPortalFn func(context.Context, stripebilling.PortalInput) (stripebilling.PortalSession, error)
717 updateQuantityFn func(context.Context, stripebilling.SeatQuantityInput) error
718 verifyWebhookFn func([]byte, string) (stripeapi.Event, error)
719 }
720
721 func (f *fakeStripeRemote) CreateCustomer(ctx context.Context, in stripebilling.CustomerInput) (stripebilling.Customer, error) {
722 if f.createCustomerFn == nil {
723 return stripebilling.Customer{}, nil
724 }
725 return f.createCustomerFn(ctx, in)
726 }
727
728 func (f *fakeStripeRemote) CreateCheckoutSession(ctx context.Context, in stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) {
729 if f.createCheckoutFn == nil {
730 return stripebilling.CheckoutSession{}, nil
731 }
732 return f.createCheckoutFn(ctx, in)
733 }
734
735 func (f *fakeStripeRemote) CreatePortalSession(ctx context.Context, in stripebilling.PortalInput) (stripebilling.PortalSession, error) {
736 if f.createPortalFn == nil {
737 return stripebilling.PortalSession{}, nil
738 }
739 return f.createPortalFn(ctx, in)
740 }
741
742 func (f *fakeStripeRemote) UpdateSubscriptionItemQuantity(ctx context.Context, in stripebilling.SeatQuantityInput) error {
743 if f.updateQuantityFn == nil {
744 return nil
745 }
746 return f.updateQuantityFn(ctx, in)
747 }
748
749 func (f *fakeStripeRemote) VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error) {
750 if f.verifyWebhookFn == nil {
751 return stripeapi.Event{}, nil
752 }
753 return f.verifyWebhookFn(payload, signatureHeader)
754 }
755
756 func stripeTestEvent(t *testing.T, id, typ string, raw map[string]any) stripeapi.Event {
757 t.Helper()
758 return stripeapi.Event{
759 ID: id,
760 Type: stripeapi.EventType(typ),
761 APIVersion: "2024-06-20",
762 Data: &stripeapi.EventData{Raw: mustJSONRaw(t, raw)},
763 }
764 }
765
766 func mustJSONRaw(t *testing.T, v any) []byte {
767 t.Helper()
768 raw, err := json.Marshal(v)
769 if err != nil {
770 t.Fatalf("marshal stripe raw object: %v", err)
771 }
772 return raw
773 }
774
// postBillingWebhook POSTs a minimal JSON payload of the form
// {"id":"<eventID>"} to /stripe/webhook on mux, with the placeholder
// Stripe-Signature header "sig_test", and returns the recorded response.
//
// The payload is produced with encoding/json rather than string
// concatenation so event IDs containing quotes, backslashes, or control
// characters still yield valid JSON. mux is typed as http.Handler (only
// ServeHTTP is needed), so a *chi.Mux or any other handler can be passed.
func postBillingWebhook(t *testing.T, mux http.Handler, eventID string) *httptest.ResponseRecorder {
	t.Helper()
	payload, err := json.Marshal(struct {
		ID string `json:"id"`
	}{ID: eventID})
	if err != nil {
		t.Fatalf("marshal webhook payload: %v", err)
	}
	req := httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(string(payload)))
	req.Header.Set("Stripe-Signature", "sig_test")
	resp := httptest.NewRecorder()
	mux.ServeHTTP(resp, req)
	return resp
}
783
784 func newOrgBillingMux(t *testing.T, pool *pgxpool.Pool, ownerID int64, remote stripebilling.Remote) *chi.Mux {
785 t.Helper()
786 return newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, remote)
787 }
788
789 // newOrgBillingMuxWithPrices is the price-aware variant used by the
790 // PRO08 cross-kind guard tests. Default mux leaves Team / Pro price
791 // IDs empty (guard short-circuits); these tests need them populated
792 // so the guard actually exercises its logic.
793 func newOrgBillingMuxWithPrices(t *testing.T, pool *pgxpool.Pool, ownerID int64, remote stripebilling.Remote, teamPriceID, proPriceID string) *chi.Mux {
794 t.Helper()
795 return newOrgBillingMuxFull(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, remote, teamPriceID, proPriceID)
796 }
797
798 func newOrgBillingMuxForUser(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, remote stripebilling.Remote) *chi.Mux {
799 t.Helper()
800 return newOrgBillingMuxFull(t, pool, viewer, remote, "", "")
801 }
802
803 func newOrgBillingMuxFull(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, remote stripebilling.Remote, teamPriceID, proPriceID string) *chi.Mux {
804 t.Helper()
805 tmplFS := fstest.MapFS{
806 "_layout.html": {Data: []byte(`{{ define "layout" }}<html><body>{{ template "page" . }}</body></html>{{ end }}`)},
807 "orgs/billing_result.html": {Data: []byte(`{{ define "page" }}RESULT={{ .Result }};HEADING={{ .Heading }};MESSAGE={{ .Message }};BILLING={{ .BillingPath }}{{ end }}`)},
808 "orgs/settings_billing.html": {Data: []byte(`{{ define "page" }}{{ with .Error }}ERROR={{ . }}{{ end }}{{ with .Notice }}NOTICE={{ . }}{{ end }}{{ with .BillingAlert }}{{ if .Message }}ALERT={{ .Message }}{{ end }}{{ end }}{{ with .Usage.Alert }}{{ if .Message }}USAGE_ALERT={{ .Message }};{{ end }}{{ end }}SEATS={{ .Seats.ActiveMembers }}/{{ .Seats.BillableSeats }}/{{ .Seats.PendingInvites }};{{ range .Usage.Rows }}USAGE={{ .Key }}:{{ .UsedLabel }}/{{ .LimitLabel }}/{{ .PercentLabel }}/{{ .StatusClass }};{{ end }}{{ range .Summary }}{{ if eq .Label "Payment source" }}PAYMENT={{ .Detail }};{{ end }}{{ end }}{{ if .IsSiteAdmin }}DEBUG={{ .Debug.StripeCustomerID }}|{{ .Debug.StripeSubscriptionID }}|{{ .Debug.StripeSubscriptionItemID }}|{{ .Debug.LastWebhookEventID }}|{{ .Debug.LastWebhookStatus }};{{ range .Debug.QuotaOverrides }}OVERRIDE={{ .Kind }}:{{ .Limit }};{{ end }}{{ end }}{{ range .Invoices }}INVOICE={{ .Number }};{{ end }}{{ end }}`)},
809 "errors/403.html": {Data: []byte(`{{ define "page" }}403{{ end }}`)},
810 "errors/404.html": {Data: []byte(`{{ define "page" }}404{{ end }}`)},
811 "errors/500.html": {Data: []byte(`{{ define "page" }}500{{ end }}`)},
812 }
813 rr, err := render.New(tmplFS, render.Options{})
814 if err != nil {
815 t.Fatalf("render.New: %v", err)
816 }
817 h, err := orgsh.New(orgsh.Deps{
818 Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
819 Render: rr,
820 Pool: pool,
821 BaseURL: "https://shithub.example",
822 BillingEnabled: true,
823 BillingGracePeriod: 14 * 24 * time.Hour,
824 Stripe: remote,
825 StripeSuccessURL: "https://shithub.example/organizations/{org}/billing/success",
826 StripeCancelURL: "https://shithub.example/organizations/{org}/billing/cancel",
827 StripePortalReturnURL: "https://shithub.example/organizations/{org}/settings/billing",
828 StripeTeamPriceID: teamPriceID,
829 StripeProPriceID: proPriceID,
830 })
831 if err != nil {
832 t.Fatalf("orgsh.New: %v", err)
833 }
834 mux := chi.NewRouter()
835 mux.Use(func(next http.Handler) http.Handler {
836 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
837 next.ServeHTTP(w, r.WithContext(middleware.WithCurrentUserForTest(r.Context(), viewer)))
838 })
839 })
840 h.MountCreate(mux)
841 h.MountBillingWebhook(mux)
842 return mux
843 }
844
845 func insertBillingPendingInvitation(t *testing.T, db *pgxpool.Pool, orgID int64, email string, token []byte) {
846 t.Helper()
847 if _, err := db.Exec(context.Background(), `
848 INSERT INTO org_invitations (org_id, target_email, role, token_hash, expires_at)
849 VALUES ($1, $2, 'member', $3, now() + interval '1 day')
850 `, orgID, email, token); err != nil {
851 t.Fatalf("insert pending billing invitation: %v", err)
852 }
853 }
854