| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "encoding/json" |
| 8 | "io" |
| 9 | "log/slog" |
| 10 | "net/http" |
| 11 | "net/http/httptest" |
| 12 | "net/url" |
| 13 | "strconv" |
| 14 | "strings" |
| 15 | "testing" |
| 16 | "testing/fstest" |
| 17 | "time" |
| 18 | |
| 19 | "github.com/go-chi/chi/v5" |
| 20 | "github.com/jackc/pgx/v5/pgxpool" |
| 21 | stripeapi "github.com/stripe/stripe-go/v85" |
| 22 | |
| 23 | orgbilling "github.com/tenseleyFlow/shithub/internal/billing" |
| 24 | billingdb "github.com/tenseleyFlow/shithub/internal/billing/sqlc" |
| 25 | "github.com/tenseleyFlow/shithub/internal/billing/stripebilling" |
| 26 | "github.com/tenseleyFlow/shithub/internal/testing/dbtest" |
| 27 | orgsh "github.com/tenseleyFlow/shithub/internal/web/handlers/orgs" |
| 28 | "github.com/tenseleyFlow/shithub/internal/web/middleware" |
| 29 | "github.com/tenseleyFlow/shithub/internal/web/render" |
| 30 | ) |
| 31 | |
| 32 | func TestOrgBillingCheckoutRedirectsToStripeAndCreatesCustomer(t *testing.T) { |
| 33 | t.Parallel() |
| 34 | ctx := context.Background() |
| 35 | pool := dbtest.NewTestDB(t) |
| 36 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 37 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 38 | fake := &fakeStripeRemote{ |
| 39 | createCustomerFn: func(_ context.Context, in stripebilling.CustomerInput) (stripebilling.Customer, error) { |
| 40 | if in.OrgID != orgID || in.OrgSlug != "acme" { |
| 41 | t.Fatalf("unexpected customer input: %+v", in) |
| 42 | } |
| 43 | return stripebilling.Customer{ID: "cus_test_checkout"}, nil |
| 44 | }, |
| 45 | createCheckoutFn: func(_ context.Context, in stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) { |
| 46 | if in.CustomerID != "cus_test_checkout" { |
| 47 | t.Fatalf("checkout customer = %q", in.CustomerID) |
| 48 | } |
| 49 | if in.SeatCount != 1 { |
| 50 | t.Fatalf("checkout seats = %d, want 1", in.SeatCount) |
| 51 | } |
| 52 | if !strings.Contains(in.SuccessURL, "/organizations/acme/billing/success") { |
| 53 | t.Fatalf("success url = %q", in.SuccessURL) |
| 54 | } |
| 55 | if !strings.Contains(in.CancelURL, "/organizations/acme/billing/cancel") { |
| 56 | t.Fatalf("cancel url = %q", in.CancelURL) |
| 57 | } |
| 58 | return stripebilling.CheckoutSession{ID: "cs_test", URL: "https://checkout.stripe.test/session"}, nil |
| 59 | }, |
| 60 | } |
| 61 | mux := newOrgBillingMux(t, pool, ownerID, fake) |
| 62 | |
| 63 | resp := httptest.NewRecorder() |
| 64 | req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/checkout", url.Values{}) |
| 65 | mux.ServeHTTP(resp, req) |
| 66 | if resp.Code != http.StatusSeeOther { |
| 67 | t.Fatalf("checkout status=%d body=%s", resp.Code, resp.Body.String()) |
| 68 | } |
| 69 | if got := resp.Header().Get("Location"); got != "https://checkout.stripe.test/session" { |
| 70 | t.Fatalf("checkout redirect=%q", got) |
| 71 | } |
| 72 | state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID) |
| 73 | if err != nil { |
| 74 | t.Fatalf("GetOrgBillingState: %v", err) |
| 75 | } |
| 76 | if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_checkout" { |
| 77 | t.Fatalf("expected stripe customer saved, got %+v", state.StripeCustomerID) |
| 78 | } |
| 79 | } |
| 80 | |
| 81 | func TestOrgBillingPortalRedirectsToStripe(t *testing.T) { |
| 82 | t.Parallel() |
| 83 | ctx := context.Background() |
| 84 | pool := dbtest.NewTestDB(t) |
| 85 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 86 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 87 | if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: pool}, orgID, "cus_test_portal"); err != nil { |
| 88 | t.Fatalf("SetStripeCustomer: %v", err) |
| 89 | } |
| 90 | fake := &fakeStripeRemote{ |
| 91 | createPortalFn: func(_ context.Context, in stripebilling.PortalInput) (stripebilling.PortalSession, error) { |
| 92 | if in.CustomerID != "cus_test_portal" { |
| 93 | t.Fatalf("portal customer = %q", in.CustomerID) |
| 94 | } |
| 95 | if !strings.Contains(in.ReturnURL, "/organizations/acme/settings/billing") { |
| 96 | t.Fatalf("portal return url = %q", in.ReturnURL) |
| 97 | } |
| 98 | return stripebilling.PortalSession{ID: "bps_test", URL: "https://billing.stripe.test/session"}, nil |
| 99 | }, |
| 100 | } |
| 101 | mux := newOrgBillingMux(t, pool, ownerID, fake) |
| 102 | |
| 103 | resp := httptest.NewRecorder() |
| 104 | req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/portal", url.Values{}) |
| 105 | mux.ServeHTTP(resp, req) |
| 106 | if resp.Code != http.StatusSeeOther { |
| 107 | t.Fatalf("portal status=%d body=%s", resp.Code, resp.Body.String()) |
| 108 | } |
| 109 | if got := resp.Header().Get("Location"); got != "https://billing.stripe.test/session" { |
| 110 | t.Fatalf("portal redirect=%q", got) |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | func TestOrgBillingResultPagesRenderPostCheckoutState(t *testing.T) { |
| 115 | t.Parallel() |
| 116 | pool := dbtest.NewTestDB(t) |
| 117 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 118 | insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 119 | mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{}) |
| 120 | |
| 121 | for _, tc := range []struct { |
| 122 | path string |
| 123 | want string |
| 124 | }{ |
| 125 | {path: "/organizations/acme/billing/success", want: "RESULT=success;HEADING=Checkout complete"}, |
| 126 | {path: "/organizations/acme/billing/cancel", want: "RESULT=canceled;HEADING=Checkout canceled"}, |
| 127 | } { |
| 128 | resp := httptest.NewRecorder() |
| 129 | req := newOrgFormRequest(http.MethodGet, tc.path, nil) |
| 130 | mux.ServeHTTP(resp, req) |
| 131 | if resp.Code != http.StatusOK { |
| 132 | t.Fatalf("%s status=%d body=%s", tc.path, resp.Code, resp.Body.String()) |
| 133 | } |
| 134 | if !strings.Contains(resp.Body.String(), tc.want) { |
| 135 | t.Fatalf("%s missing %q in body %s", tc.path, tc.want, resp.Body.String()) |
| 136 | } |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | func TestOrgBillingSettingsRequiresOwner(t *testing.T) { |
| 141 | t.Parallel() |
| 142 | pool := dbtest.NewTestDB(t) |
| 143 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 144 | insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 145 | memberID := insertOrgAvatarUser(t, pool, "member") |
| 146 | mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: memberID, Username: "member"}, &fakeStripeRemote{}) |
| 147 | |
| 148 | resp := httptest.NewRecorder() |
| 149 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 150 | mux.ServeHTTP(resp, req) |
| 151 | if resp.Code != http.StatusForbidden { |
| 152 | t.Fatalf("settings status=%d body=%s", resp.Code, resp.Body.String()) |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | func TestOrgBillingSettingsRendersSeatBreakdownAndHidesStripeIDs(t *testing.T) { |
| 157 | t.Parallel() |
| 158 | ctx := context.Background() |
| 159 | pool := dbtest.NewTestDB(t) |
| 160 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 161 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 162 | if _, err := orgbilling.SetStripeCustomer(ctx, orgbilling.Deps{Pool: pool}, orgID, "cus_owner_secret"); err != nil { |
| 163 | t.Fatalf("SetStripeCustomer: %v", err) |
| 164 | } |
| 165 | if _, err := orgbilling.SyncSeatSnapshot(ctx, orgbilling.Deps{Pool: pool}, orgbilling.SeatSnapshot{ |
| 166 | OrgID: orgID, |
| 167 | ActiveMembers: 3, |
| 168 | BillableSeats: 3, |
| 169 | Source: "test", |
| 170 | }); err != nil { |
| 171 | t.Fatalf("SyncSeatSnapshot: %v", err) |
| 172 | } |
| 173 | insertBillingPendingInvitation(t, pool, orgID, "pending@example.com", []byte{1, 2, 3}) |
| 174 | mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{}) |
| 175 | |
| 176 | resp := httptest.NewRecorder() |
| 177 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 178 | mux.ServeHTTP(resp, req) |
| 179 | body := resp.Body.String() |
| 180 | if resp.Code != http.StatusOK { |
| 181 | t.Fatalf("settings status=%d body=%s", resp.Code, body) |
| 182 | } |
| 183 | if !strings.Contains(body, "SEATS=1/3/1;") { |
| 184 | t.Fatalf("settings did not render seat breakdown: %s", body) |
| 185 | } |
| 186 | if strings.Contains(body, "cus_owner_secret") { |
| 187 | t.Fatalf("owner billing page leaked Stripe customer id: %s", body) |
| 188 | } |
| 189 | } |
| 190 | |
| 191 | func TestOrgBillingSettingsRendersUsageThresholds(t *testing.T) { |
| 192 | t.Parallel() |
| 193 | ctx := context.Background() |
| 194 | pool := dbtest.NewTestDB(t) |
| 195 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 196 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 197 | start, end := orgbilling.MonthlyUsagePeriod(time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC)) |
| 198 | if _, err := orgbilling.UpsertOrgUsageCounters(ctx, orgbilling.Deps{Pool: pool}, orgbilling.UsageCounterSnapshot{ |
| 199 | OrgID: orgID, |
| 200 | RepoStorageBytes: 490 * 1024 * 1024, |
| 201 | ObjectStorageBytes: 0, |
| 202 | ActionsLogBytes: 0, |
| 203 | ActionsArtifactBytes: 0, |
| 204 | ActionsMinutesUsed: 1600, |
| 205 | ActionsPeriodStart: start, |
| 206 | ActionsPeriodEnd: end, |
| 207 | CalculatedAt: start.Add(12 * time.Hour), |
| 208 | }); err != nil { |
| 209 | t.Fatalf("UpsertOrgUsageCounters: %v", err) |
| 210 | } |
| 211 | mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{}) |
| 212 | |
| 213 | resp := httptest.NewRecorder() |
| 214 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 215 | mux.ServeHTTP(resp, req) |
| 216 | body := resp.Body.String() |
| 217 | if resp.Code != http.StatusOK { |
| 218 | t.Fatalf("settings status=%d body=%s", resp.Code, body) |
| 219 | } |
| 220 | if !strings.Contains(body, "USAGE=storage:490 MiB/500 MiB/98%/is-danger;") { |
| 221 | t.Fatalf("settings did not render storage usage threshold: %s", body) |
| 222 | } |
| 223 | if !strings.Contains(body, "USAGE=actions-minutes:1600 minutes/2000 minutes/80%/is-warning;") { |
| 224 | t.Fatalf("settings did not render actions usage threshold: %s", body) |
| 225 | } |
| 226 | if !strings.Contains(body, "USAGE_ALERT=This organization has used at least 95% of its storage quota.") { |
| 227 | t.Fatalf("settings did not render quota warning: %s", body) |
| 228 | } |
| 229 | } |
| 230 | |
| 231 | func TestOrgBillingSettingsAppliesQuotaOverrides(t *testing.T) { |
| 232 | t.Parallel() |
| 233 | ctx := context.Background() |
| 234 | pool := dbtest.NewTestDB(t) |
| 235 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 236 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 237 | start, end := orgbilling.MonthlyUsagePeriod(time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC)) |
| 238 | deps := orgbilling.Deps{Pool: pool} |
| 239 | if _, err := orgbilling.UpsertOrgUsageCounters(ctx, deps, orgbilling.UsageCounterSnapshot{ |
| 240 | OrgID: orgID, |
| 241 | RepoStorageBytes: 6 * 1024 * 1024 * 1024, |
| 242 | ObjectStorageBytes: 0, |
| 243 | ActionsLogBytes: 0, |
| 244 | ActionsArtifactBytes: 0, |
| 245 | ActionsMinutesUsed: 2100, |
| 246 | ActionsPeriodStart: start, |
| 247 | ActionsPeriodEnd: end, |
| 248 | CalculatedAt: start.Add(12 * time.Hour), |
| 249 | }); err != nil { |
| 250 | t.Fatalf("UpsertOrgUsageCounters: %v", err) |
| 251 | } |
| 252 | if _, err := orgbilling.UpsertOrgQuotaOverride(ctx, deps, orgbilling.QuotaOverrideInput{ |
| 253 | OrgID: orgID, |
| 254 | Kind: orgbilling.QuotaKindStorageBytes, |
| 255 | LimitValue: 10 * 1024 * 1024 * 1024, |
| 256 | Note: "temporary migration", |
| 257 | CreatedByUserID: ownerID, |
| 258 | }); err != nil { |
| 259 | t.Fatalf("UpsertOrgQuotaOverride storage: %v", err) |
| 260 | } |
| 261 | if _, err := orgbilling.UpsertOrgQuotaOverride(ctx, deps, orgbilling.QuotaOverrideInput{ |
| 262 | OrgID: orgID, |
| 263 | Kind: orgbilling.QuotaKindActionsMinutes, |
| 264 | Unlimited: true, |
| 265 | CreatedByUserID: ownerID, |
| 266 | }); err != nil { |
| 267 | t.Fatalf("UpsertOrgQuotaOverride actions: %v", err) |
| 268 | } |
| 269 | mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner", IsSiteAdmin: true}, &fakeStripeRemote{}) |
| 270 | |
| 271 | resp := httptest.NewRecorder() |
| 272 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 273 | mux.ServeHTTP(resp, req) |
| 274 | body := resp.Body.String() |
| 275 | if resp.Code != http.StatusOK { |
| 276 | t.Fatalf("settings status=%d body=%s", resp.Code, body) |
| 277 | } |
| 278 | if !strings.Contains(body, "USAGE=storage:6 GiB/10 GiB override/60%/is-ok;") { |
| 279 | t.Fatalf("settings did not apply storage override: %s", body) |
| 280 | } |
| 281 | if !strings.Contains(body, "USAGE=actions-minutes:2100 minutes/Unlimited override/Unlimited/is-ok;") { |
| 282 | t.Fatalf("settings did not apply actions override: %s", body) |
| 283 | } |
| 284 | if !strings.Contains(body, "OVERRIDE=Storage:10 GiB;") || !strings.Contains(body, "OVERRIDE=Actions minutes:Unlimited;") { |
| 285 | t.Fatalf("settings did not render site-admin quota overrides: %s", body) |
| 286 | } |
| 287 | } |
| 288 | |
| 289 | func TestOrgBillingSiteAdminCanManageQuotaOverrides(t *testing.T) { |
| 290 | t.Parallel() |
| 291 | ctx := context.Background() |
| 292 | pool := dbtest.NewTestDB(t) |
| 293 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 294 | adminID := insertOrgAvatarUser(t, pool, "admin") |
| 295 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 296 | mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: adminID, Username: "admin", IsSiteAdmin: true}, &fakeStripeRemote{}) |
| 297 | |
| 298 | resp := httptest.NewRecorder() |
| 299 | req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{ |
| 300 | "kind": {"storage_bytes"}, |
| 301 | "limit_value": {"1073741824"}, |
| 302 | "note": {"support migration"}, |
| 303 | }) |
| 304 | mux.ServeHTTP(resp, req) |
| 305 | if resp.Code != http.StatusSeeOther { |
| 306 | t.Fatalf("save status=%d body=%s", resp.Code, resp.Body.String()) |
| 307 | } |
| 308 | if got := resp.Header().Get("Location"); got != "/organizations/acme/settings/billing?notice=quota-override-saved" { |
| 309 | t.Fatalf("save redirect=%q", got) |
| 310 | } |
| 311 | override, err := orgbilling.GetOrgQuotaOverride(ctx, orgbilling.Deps{Pool: pool}, orgID, orgbilling.QuotaKindStorageBytes) |
| 312 | if err != nil { |
| 313 | t.Fatalf("GetOrgQuotaOverride: %v", err) |
| 314 | } |
| 315 | if override.Unlimited || !override.LimitValue.Valid || override.LimitValue.Int64 != 1073741824 || override.Note != "support migration" { |
| 316 | t.Fatalf("unexpected override: %+v", override) |
| 317 | } |
| 318 | if !override.CreatedByUserID.Valid || override.CreatedByUserID.Int64 != adminID { |
| 319 | t.Fatalf("created_by_user_id=%+v, want %d", override.CreatedByUserID, adminID) |
| 320 | } |
| 321 | |
| 322 | resp = httptest.NewRecorder() |
| 323 | req = newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{ |
| 324 | "kind": {"actions_minutes"}, |
| 325 | "unlimited": {"1"}, |
| 326 | }) |
| 327 | mux.ServeHTTP(resp, req) |
| 328 | if resp.Code != http.StatusSeeOther { |
| 329 | t.Fatalf("unlimited save status=%d body=%s", resp.Code, resp.Body.String()) |
| 330 | } |
| 331 | actionsOverride, err := orgbilling.GetOrgQuotaOverride(ctx, orgbilling.Deps{Pool: pool}, orgID, orgbilling.QuotaKindActionsMinutes) |
| 332 | if err != nil { |
| 333 | t.Fatalf("GetOrgQuotaOverride actions: %v", err) |
| 334 | } |
| 335 | if !actionsOverride.Unlimited || actionsOverride.LimitValue.Valid { |
| 336 | t.Fatalf("unexpected unlimited override: %+v", actionsOverride) |
| 337 | } |
| 338 | |
| 339 | resp = httptest.NewRecorder() |
| 340 | req = newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides/delete", url.Values{ |
| 341 | "kind": {"storage_bytes"}, |
| 342 | }) |
| 343 | mux.ServeHTTP(resp, req) |
| 344 | if resp.Code != http.StatusSeeOther { |
| 345 | t.Fatalf("delete status=%d body=%s", resp.Code, resp.Body.String()) |
| 346 | } |
| 347 | if got := resp.Header().Get("Location"); got != "/organizations/acme/settings/billing?notice=quota-override-cleared" { |
| 348 | t.Fatalf("delete redirect=%q", got) |
| 349 | } |
| 350 | overrides, err := orgbilling.ListOrgQuotaOverrides(ctx, orgbilling.Deps{Pool: pool}, orgID) |
| 351 | if err != nil { |
| 352 | t.Fatalf("ListOrgQuotaOverrides: %v", err) |
| 353 | } |
| 354 | if len(overrides) != 1 || overrides[0].Kind != orgbilling.QuotaKindActionsMinutes { |
| 355 | t.Fatalf("unexpected remaining overrides: %+v", overrides) |
| 356 | } |
| 357 | } |
| 358 | |
| 359 | func TestOrgBillingQuotaOverridesRequireSiteAdmin(t *testing.T) { |
| 360 | t.Parallel() |
| 361 | ctx := context.Background() |
| 362 | pool := dbtest.NewTestDB(t) |
| 363 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 364 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 365 | mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{}) |
| 366 | |
| 367 | resp := httptest.NewRecorder() |
| 368 | req := newOrgFormRequest(http.MethodPost, "/organizations/acme/billing/quota-overrides", url.Values{ |
| 369 | "kind": {"storage_bytes"}, |
| 370 | "limit_value": {"1"}, |
| 371 | }) |
| 372 | mux.ServeHTTP(resp, req) |
| 373 | if resp.Code != http.StatusNotFound { |
| 374 | t.Fatalf("save status=%d body=%s", resp.Code, resp.Body.String()) |
| 375 | } |
| 376 | overrides, err := orgbilling.ListOrgQuotaOverrides(ctx, orgbilling.Deps{Pool: pool}, orgID) |
| 377 | if err != nil { |
| 378 | t.Fatalf("ListOrgQuotaOverrides: %v", err) |
| 379 | } |
| 380 | if len(overrides) != 0 { |
| 381 | t.Fatalf("non-admin created overrides: %+v", overrides) |
| 382 | } |
| 383 | } |
| 384 | |
| 385 | func TestOrgBillingSettingsSiteAdminDebugShowsProviderState(t *testing.T) { |
| 386 | t.Parallel() |
| 387 | ctx := context.Background() |
| 388 | pool := dbtest.NewTestDB(t) |
| 389 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 390 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 391 | deps := orgbilling.Deps{Pool: pool} |
| 392 | if _, err := orgbilling.SetStripeCustomer(ctx, deps, orgID, "cus_debug"); err != nil { |
| 393 | t.Fatalf("SetStripeCustomer: %v", err) |
| 394 | } |
| 395 | if _, err := orgbilling.ApplySubscriptionSnapshot(ctx, deps, orgbilling.SubscriptionSnapshot{ |
| 396 | OrgID: orgID, |
| 397 | Plan: orgbilling.PlanTeam, |
| 398 | Status: orgbilling.SubscriptionStatusActive, |
| 399 | StripeSubscriptionID: "sub_debug", |
| 400 | StripeSubscriptionItemID: "si_debug", |
| 401 | CurrentPeriodStart: time.Now().UTC().Add(-time.Hour), |
| 402 | CurrentPeriodEnd: time.Now().UTC().Add(30 * 24 * time.Hour), |
| 403 | LastWebhookEventID: "evt_debug", |
| 404 | }); err != nil { |
| 405 | t.Fatalf("ApplySubscriptionSnapshot: %v", err) |
| 406 | } |
| 407 | if _, _, err := orgbilling.RecordWebhookEvent(ctx, deps, orgbilling.WebhookEvent{ |
| 408 | ProviderEventID: "evt_debug", |
| 409 | EventType: "customer.subscription.updated", |
| 410 | APIVersion: "2024-06-20", |
| 411 | Payload: []byte(`{"id":"evt_debug"}`), |
| 412 | }); err != nil { |
| 413 | t.Fatalf("RecordWebhookEvent: %v", err) |
| 414 | } |
| 415 | if _, err := orgbilling.MarkWebhookEventProcessed(ctx, deps, "evt_debug"); err != nil { |
| 416 | t.Fatalf("MarkWebhookEventProcessed: %v", err) |
| 417 | } |
| 418 | mux := newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner", IsSiteAdmin: true}, &fakeStripeRemote{}) |
| 419 | |
| 420 | resp := httptest.NewRecorder() |
| 421 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 422 | mux.ServeHTTP(resp, req) |
| 423 | body := resp.Body.String() |
| 424 | if resp.Code != http.StatusOK { |
| 425 | t.Fatalf("settings status=%d body=%s", resp.Code, body) |
| 426 | } |
| 427 | if !strings.Contains(body, "DEBUG=cus_debug|sub_debug|si_debug|evt_debug|processed;") { |
| 428 | t.Fatalf("settings did not render site-admin debug state: %s", body) |
| 429 | } |
| 430 | } |
| 431 | |
| 432 | func TestOrgBillingSettingsShowsPastDueAlert(t *testing.T) { |
| 433 | t.Parallel() |
| 434 | ctx := context.Background() |
| 435 | pool := dbtest.NewTestDB(t) |
| 436 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 437 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 438 | if _, err := orgbilling.MarkPastDue(ctx, orgbilling.Deps{Pool: pool}, orgID, time.Now().UTC().Add(24*time.Hour), "evt_failed"); err != nil { |
| 439 | t.Fatalf("MarkPastDue: %v", err) |
| 440 | } |
| 441 | mux := newOrgBillingMux(t, pool, ownerID, &fakeStripeRemote{}) |
| 442 | |
| 443 | resp := httptest.NewRecorder() |
| 444 | req := newOrgFormRequest(http.MethodGet, "/organizations/acme/settings/billing", nil) |
| 445 | mux.ServeHTTP(resp, req) |
| 446 | body := resp.Body.String() |
| 447 | if resp.Code != http.StatusOK { |
| 448 | t.Fatalf("settings status=%d body=%s", resp.Code, body) |
| 449 | } |
| 450 | if !strings.Contains(body, "ALERT=Payment failed.") { |
| 451 | t.Fatalf("settings did not render past-due alert: %s", body) |
| 452 | } |
| 453 | } |
| 454 | |
| 455 | func TestOrgBillingWebhookProcessesSubscriptionAndStaysIdempotent(t *testing.T) { |
| 456 | t.Parallel() |
| 457 | ctx := context.Background() |
| 458 | pool := dbtest.NewTestDB(t) |
| 459 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 460 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 461 | raw, err := json.Marshal(map[string]any{ |
| 462 | "id": "sub_test", |
| 463 | "customer": "cus_test_webhook", |
| 464 | "status": "active", |
| 465 | "cancel_at_period_end": false, |
| 466 | "trial_end": int64(0), |
| 467 | "canceled_at": int64(0), |
| 468 | "metadata": map[string]string{stripebilling.MetadataOrgID: strconv.FormatInt(orgID, 10)}, |
| 469 | "items": map[string]any{"data": []map[string]any{{ |
| 470 | "id": "si_test_webhook", |
| 471 | "current_period_start": time.Now().UTC().Add(-time.Hour).Unix(), |
| 472 | "current_period_end": time.Now().UTC().Add(30 * 24 * time.Hour).Unix(), |
| 473 | }}}, |
| 474 | }) |
| 475 | if err != nil { |
| 476 | t.Fatalf("marshal subscription raw: %v", err) |
| 477 | } |
| 478 | fake := &fakeStripeRemote{ |
| 479 | verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) { |
| 480 | return stripeapi.Event{ |
| 481 | ID: "evt_sub_active", |
| 482 | Type: stripeapi.EventType("customer.subscription.updated"), |
| 483 | APIVersion: "2024-06-20", |
| 484 | Data: &stripeapi.EventData{Raw: raw}, |
| 485 | }, nil |
| 486 | }, |
| 487 | } |
| 488 | mux := newOrgBillingMux(t, pool, ownerID, fake) |
| 489 | |
| 490 | req := httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(`{"id":"evt_sub_active"}`)) |
| 491 | req.Header.Set("Stripe-Signature", "sig_test") |
| 492 | resp := httptest.NewRecorder() |
| 493 | mux.ServeHTTP(resp, req) |
| 494 | if resp.Code != http.StatusOK { |
| 495 | t.Fatalf("first webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 496 | } |
| 497 | state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID) |
| 498 | if err != nil { |
| 499 | t.Fatalf("GetOrgBillingState: %v", err) |
| 500 | } |
| 501 | if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusActive { |
| 502 | t.Fatalf("unexpected billing state: %+v", state) |
| 503 | } |
| 504 | if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_webhook" { |
| 505 | t.Fatalf("expected customer id saved, got %+v", state.StripeCustomerID) |
| 506 | } |
| 507 | if !state.StripeSubscriptionID.Valid || state.StripeSubscriptionID.String != "sub_test" { |
| 508 | t.Fatalf("expected subscription id saved, got %+v", state.StripeSubscriptionID) |
| 509 | } |
| 510 | receipt, err := billingdb.New().GetWebhookEventReceipt(ctx, pool, "evt_sub_active") |
| 511 | if err != nil { |
| 512 | t.Fatalf("GetWebhookEventReceipt: %v", err) |
| 513 | } |
| 514 | if !receipt.ProcessedAt.Valid || receipt.ProcessingAttempts != 1 { |
| 515 | t.Fatalf("unexpected receipt after first processing: %+v", receipt) |
| 516 | } |
| 517 | // PRO08 A2: subject must be recorded on the receipt after resolve. |
| 518 | if !receipt.SubjectKind.Valid || receipt.SubjectKind.BillingSubjectKind != billingdb.BillingSubjectKindOrg { |
| 519 | t.Fatalf("receipt subject_kind: got %+v, want org", receipt.SubjectKind) |
| 520 | } |
| 521 | if !receipt.SubjectID.Valid || receipt.SubjectID.Int64 != orgID { |
| 522 | t.Fatalf("receipt subject_id: got %+v, want %d", receipt.SubjectID, orgID) |
| 523 | } |
| 524 | |
| 525 | req = httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(`{"id":"evt_sub_active"}`)) |
| 526 | req.Header.Set("Stripe-Signature", "sig_test") |
| 527 | resp = httptest.NewRecorder() |
| 528 | mux.ServeHTTP(resp, req) |
| 529 | if resp.Code != http.StatusOK { |
| 530 | t.Fatalf("duplicate webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 531 | } |
| 532 | receipt, err = billingdb.New().GetWebhookEventReceipt(ctx, pool, "evt_sub_active") |
| 533 | if err != nil { |
| 534 | t.Fatalf("GetWebhookEventReceipt duplicate: %v", err) |
| 535 | } |
| 536 | if receipt.ProcessingAttempts != 1 { |
| 537 | t.Fatalf("duplicate webhook should not reprocess receipt: %+v", receipt) |
| 538 | } |
| 539 | } |
| 540 | |
| 541 | func TestOrgBillingWebhookCheckoutCompletedStoresCustomerOnly(t *testing.T) { |
| 542 | t.Parallel() |
| 543 | ctx := context.Background() |
| 544 | pool := dbtest.NewTestDB(t) |
| 545 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 546 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 547 | raw := mustJSONRaw(t, map[string]any{ |
| 548 | "id": "cs_test_completed", |
| 549 | "object": "checkout.session", |
| 550 | "customer": "cus_test_checkout_completed", |
| 551 | "client_reference_id": strconv.FormatInt(orgID, 10), |
| 552 | }) |
| 553 | fake := &fakeStripeRemote{ |
| 554 | verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) { |
| 555 | return stripeapi.Event{ |
| 556 | ID: "evt_checkout_completed", |
| 557 | Type: stripeapi.EventType("checkout.session.completed"), |
| 558 | APIVersion: "2024-06-20", |
| 559 | Data: &stripeapi.EventData{Raw: raw}, |
| 560 | }, nil |
| 561 | }, |
| 562 | } |
| 563 | mux := newOrgBillingMux(t, pool, ownerID, fake) |
| 564 | |
| 565 | resp := postBillingWebhook(t, mux, "evt_checkout_completed") |
| 566 | if resp.Code != http.StatusOK { |
| 567 | t.Fatalf("checkout webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 568 | } |
| 569 | state, err := orgbilling.GetOrgBillingState(ctx, orgbilling.Deps{Pool: pool}, orgID) |
| 570 | if err != nil { |
| 571 | t.Fatalf("GetOrgBillingState: %v", err) |
| 572 | } |
| 573 | if !state.StripeCustomerID.Valid || state.StripeCustomerID.String != "cus_test_checkout_completed" { |
| 574 | t.Fatalf("expected checkout customer saved, got %+v", state.StripeCustomerID) |
| 575 | } |
| 576 | if state.Plan != orgbilling.PlanFree || state.SubscriptionStatus != orgbilling.SubscriptionStatusNone { |
| 577 | t.Fatalf("checkout completion must not activate paid state by itself: %+v", state) |
| 578 | } |
| 579 | } |
| 580 | |
| 581 | func TestOrgBillingWebhookHandlesInvoiceFailureRecoveryAndCancellation(t *testing.T) { |
| 582 | t.Parallel() |
| 583 | ctx := context.Background() |
| 584 | pool := dbtest.NewTestDB(t) |
| 585 | ownerID := insertOrgAvatarUser(t, pool, "owner") |
| 586 | orgID := insertOrgAvatarOrg(t, pool, ownerID, "acme") |
| 587 | deps := orgbilling.Deps{Pool: pool} |
| 588 | if _, err := orgbilling.SetStripeCustomer(ctx, deps, orgID, "cus_test_lifecycle"); err != nil { |
| 589 | t.Fatalf("SetStripeCustomer: %v", err) |
| 590 | } |
| 591 | start := time.Now().UTC().Truncate(time.Second) |
| 592 | if _, err := orgbilling.ApplySubscriptionSnapshot(ctx, deps, orgbilling.SubscriptionSnapshot{ |
| 593 | OrgID: orgID, |
| 594 | Plan: orgbilling.PlanTeam, |
| 595 | Status: orgbilling.SubscriptionStatusActive, |
| 596 | StripeSubscriptionID: "sub_test_lifecycle", |
| 597 | StripeSubscriptionItemID: "si_test_lifecycle", |
| 598 | CurrentPeriodStart: start, |
| 599 | CurrentPeriodEnd: start.Add(30 * 24 * time.Hour), |
| 600 | LastWebhookEventID: "evt_seed_active", |
| 601 | }); err != nil { |
| 602 | t.Fatalf("ApplySubscriptionSnapshot: %v", err) |
| 603 | } |
| 604 | |
| 605 | var current stripeapi.Event |
| 606 | fake := &fakeStripeRemote{ |
| 607 | verifyWebhookFn: func(_ []byte, _ string) (stripeapi.Event, error) { |
| 608 | return current, nil |
| 609 | }, |
| 610 | } |
| 611 | mux := newOrgBillingMux(t, pool, ownerID, fake) |
| 612 | |
| 613 | current = stripeTestEvent(t, "evt_invoice_failed", "invoice.payment_failed", map[string]any{ |
| 614 | "id": "in_test_lifecycle", |
| 615 | "object": "invoice", |
| 616 | "customer": "cus_test_lifecycle", |
| 617 | "status": "open", |
| 618 | "number": "INV-FAILED", |
| 619 | "currency": "usd", |
| 620 | "amount_due": int64(400), |
| 621 | "amount_paid": int64(0), |
| 622 | "amount_remaining": int64(400), |
| 623 | "hosted_invoice_url": "https://pay.stripe.test/invoice", |
| 624 | "invoice_pdf": "https://pay.stripe.test/invoice.pdf", |
| 625 | "period_start": start.Unix(), |
| 626 | "period_end": start.Add(30 * 24 * time.Hour).Unix(), |
| 627 | "due_date": start.Add(3 * 24 * time.Hour).Unix(), |
| 628 | "status_transitions": map[string]any{}, |
| 629 | "subscription_details": map[string]any{}, |
| 630 | }) |
| 631 | resp := postBillingWebhook(t, mux, "evt_invoice_failed") |
| 632 | if resp.Code != http.StatusOK { |
| 633 | t.Fatalf("payment_failed webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 634 | } |
| 635 | state, err := orgbilling.GetOrgBillingState(ctx, deps, orgID) |
| 636 | if err != nil { |
| 637 | t.Fatalf("GetOrgBillingState after failed payment: %v", err) |
| 638 | } |
| 639 | if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusPastDue { |
| 640 | t.Fatalf("payment_failed should keep Team plan and mark past_due: %+v", state) |
| 641 | } |
| 642 | if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonPastDue { |
| 643 | t.Fatalf("payment_failed should set past_due lock fields: %+v", state) |
| 644 | } |
| 645 | if !state.GraceUntil.Valid { |
| 646 | t.Fatalf("payment_failed should set grace_until: %+v", state) |
| 647 | } |
| 648 | |
| 649 | current = stripeTestEvent(t, "evt_invoice_paid", "invoice.payment_succeeded", map[string]any{ |
| 650 | "id": "in_test_lifecycle", |
| 651 | "object": "invoice", |
| 652 | "customer": "cus_test_lifecycle", |
| 653 | "status": "paid", |
| 654 | "number": "INV-FAILED", |
| 655 | "currency": "usd", |
| 656 | "amount_due": int64(400), |
| 657 | "amount_paid": int64(400), |
| 658 | "amount_remaining": int64(0), |
| 659 | "period_start": start.Unix(), |
| 660 | "period_end": start.Add(30 * 24 * time.Hour).Unix(), |
| 661 | "status_transitions": map[string]any{"paid_at": start.Add(time.Hour).Unix()}, |
| 662 | "subscription_details": map[string]any{}, |
| 663 | }) |
| 664 | resp = postBillingWebhook(t, mux, "evt_invoice_paid") |
| 665 | if resp.Code != http.StatusOK { |
| 666 | t.Fatalf("payment_succeeded webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 667 | } |
| 668 | state, err = orgbilling.GetOrgBillingState(ctx, deps, orgID) |
| 669 | if err != nil { |
| 670 | t.Fatalf("GetOrgBillingState after paid invoice: %v", err) |
| 671 | } |
| 672 | if state.Plan != orgbilling.PlanTeam || state.SubscriptionStatus != orgbilling.SubscriptionStatusActive { |
| 673 | t.Fatalf("payment_succeeded should recover active Team state: %+v", state) |
| 674 | } |
| 675 | if state.LockedAt.Valid || state.LockReason.Valid || state.GraceUntil.Valid { |
| 676 | t.Fatalf("payment_succeeded should clear billing action lock: %+v", state) |
| 677 | } |
| 678 | if state.PastDueAt.Valid { |
| 679 | t.Fatalf("payment_succeeded should clear past_due_at: %+v", state) |
| 680 | } |
| 681 | |
| 682 | current = stripeTestEvent(t, "evt_subscription_deleted", "customer.subscription.deleted", map[string]any{ |
| 683 | "id": "sub_test_lifecycle", |
| 684 | "object": "subscription", |
| 685 | "customer": "cus_test_lifecycle", |
| 686 | "status": "canceled", |
| 687 | "cancel_at_period_end": false, |
| 688 | "trial_end": int64(0), |
| 689 | "canceled_at": start.Add(2 * time.Hour).Unix(), |
| 690 | "metadata": map[string]string{stripebilling.MetadataOrgID: strconv.FormatInt(orgID, 10)}, |
| 691 | "items": map[string]any{"data": []map[string]any{{ |
| 692 | "id": "si_test_lifecycle", |
| 693 | "current_period_start": start.Unix(), |
| 694 | "current_period_end": start.Add(30 * 24 * time.Hour).Unix(), |
| 695 | }}}, |
| 696 | }) |
| 697 | resp = postBillingWebhook(t, mux, "evt_subscription_deleted") |
| 698 | if resp.Code != http.StatusOK { |
| 699 | t.Fatalf("subscription deleted webhook status=%d body=%s", resp.Code, resp.Body.String()) |
| 700 | } |
| 701 | state, err = orgbilling.GetOrgBillingState(ctx, deps, orgID) |
| 702 | if err != nil { |
| 703 | t.Fatalf("GetOrgBillingState after cancellation: %v", err) |
| 704 | } |
| 705 | if state.Plan != orgbilling.PlanFree || state.SubscriptionStatus != orgbilling.SubscriptionStatusCanceled { |
| 706 | t.Fatalf("subscription deletion should downgrade to Free canceled state: %+v", state) |
| 707 | } |
| 708 | if !state.LockedAt.Valid || !state.LockReason.Valid || state.LockReason.BillingLockReason != billingdb.BillingLockReasonCanceled { |
| 709 | t.Fatalf("subscription deletion should set canceled lock fields: %+v", state) |
| 710 | } |
| 711 | } |
| 712 | |
| 713 | type fakeStripeRemote struct { |
| 714 | createCustomerFn func(context.Context, stripebilling.CustomerInput) (stripebilling.Customer, error) |
| 715 | createCheckoutFn func(context.Context, stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) |
| 716 | createPortalFn func(context.Context, stripebilling.PortalInput) (stripebilling.PortalSession, error) |
| 717 | updateQuantityFn func(context.Context, stripebilling.SeatQuantityInput) error |
| 718 | verifyWebhookFn func([]byte, string) (stripeapi.Event, error) |
| 719 | } |
| 720 | |
| 721 | func (f *fakeStripeRemote) CreateCustomer(ctx context.Context, in stripebilling.CustomerInput) (stripebilling.Customer, error) { |
| 722 | if f.createCustomerFn == nil { |
| 723 | return stripebilling.Customer{}, nil |
| 724 | } |
| 725 | return f.createCustomerFn(ctx, in) |
| 726 | } |
| 727 | |
| 728 | func (f *fakeStripeRemote) CreateCheckoutSession(ctx context.Context, in stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) { |
| 729 | if f.createCheckoutFn == nil { |
| 730 | return stripebilling.CheckoutSession{}, nil |
| 731 | } |
| 732 | return f.createCheckoutFn(ctx, in) |
| 733 | } |
| 734 | |
| 735 | func (f *fakeStripeRemote) CreatePortalSession(ctx context.Context, in stripebilling.PortalInput) (stripebilling.PortalSession, error) { |
| 736 | if f.createPortalFn == nil { |
| 737 | return stripebilling.PortalSession{}, nil |
| 738 | } |
| 739 | return f.createPortalFn(ctx, in) |
| 740 | } |
| 741 | |
| 742 | func (f *fakeStripeRemote) UpdateSubscriptionItemQuantity(ctx context.Context, in stripebilling.SeatQuantityInput) error { |
| 743 | if f.updateQuantityFn == nil { |
| 744 | return nil |
| 745 | } |
| 746 | return f.updateQuantityFn(ctx, in) |
| 747 | } |
| 748 | |
| 749 | func (f *fakeStripeRemote) VerifyWebhook(payload []byte, signatureHeader string) (stripeapi.Event, error) { |
| 750 | if f.verifyWebhookFn == nil { |
| 751 | return stripeapi.Event{}, nil |
| 752 | } |
| 753 | return f.verifyWebhookFn(payload, signatureHeader) |
| 754 | } |
| 755 | |
| 756 | func stripeTestEvent(t *testing.T, id, typ string, raw map[string]any) stripeapi.Event { |
| 757 | t.Helper() |
| 758 | return stripeapi.Event{ |
| 759 | ID: id, |
| 760 | Type: stripeapi.EventType(typ), |
| 761 | APIVersion: "2024-06-20", |
| 762 | Data: &stripeapi.EventData{Raw: mustJSONRaw(t, raw)}, |
| 763 | } |
| 764 | } |
| 765 | |
// mustJSONRaw returns the JSON encoding of v, failing the test immediately
// when v cannot be marshalled.
func mustJSONRaw(t *testing.T, v any) []byte {
	t.Helper()
	encoded, marshalErr := json.Marshal(v)
	if marshalErr == nil {
		return encoded
	}
	t.Fatalf("marshal stripe raw object: %v", marshalErr)
	return nil // not reached: Fatalf stops the calling test goroutine
}
| 774 | |
| 775 | func postBillingWebhook(t *testing.T, mux *chi.Mux, eventID string) *httptest.ResponseRecorder { |
| 776 | t.Helper() |
| 777 | req := httptest.NewRequest(http.MethodPost, "/stripe/webhook", strings.NewReader(`{"id":"`+eventID+`"}`)) |
| 778 | req.Header.Set("Stripe-Signature", "sig_test") |
| 779 | resp := httptest.NewRecorder() |
| 780 | mux.ServeHTTP(resp, req) |
| 781 | return resp |
| 782 | } |
| 783 | |
| 784 | func newOrgBillingMux(t *testing.T, pool *pgxpool.Pool, ownerID int64, remote stripebilling.Remote) *chi.Mux { |
| 785 | t.Helper() |
| 786 | return newOrgBillingMuxForUser(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, remote) |
| 787 | } |
| 788 | |
| 789 | // newOrgBillingMuxWithPrices is the price-aware variant used by the |
| 790 | // PRO08 cross-kind guard tests. Default mux leaves Team / Pro price |
| 791 | // IDs empty (guard short-circuits); these tests need them populated |
| 792 | // so the guard actually exercises its logic. |
| 793 | func newOrgBillingMuxWithPrices(t *testing.T, pool *pgxpool.Pool, ownerID int64, remote stripebilling.Remote, teamPriceID, proPriceID string) *chi.Mux { |
| 794 | t.Helper() |
| 795 | return newOrgBillingMuxFull(t, pool, middleware.CurrentUser{ID: ownerID, Username: "owner"}, remote, teamPriceID, proPriceID) |
| 796 | } |
| 797 | |
| 798 | func newOrgBillingMuxForUser(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, remote stripebilling.Remote) *chi.Mux { |
| 799 | t.Helper() |
| 800 | return newOrgBillingMuxFull(t, pool, viewer, remote, "", "") |
| 801 | } |
| 802 | |
| 803 | func newOrgBillingMuxFull(t *testing.T, pool *pgxpool.Pool, viewer middleware.CurrentUser, remote stripebilling.Remote, teamPriceID, proPriceID string) *chi.Mux { |
| 804 | t.Helper() |
| 805 | tmplFS := fstest.MapFS{ |
| 806 | "_layout.html": {Data: []byte(`{{ define "layout" }}<html><body>{{ template "page" . }}</body></html>{{ end }}`)}, |
| 807 | "orgs/billing_result.html": {Data: []byte(`{{ define "page" }}RESULT={{ .Result }};HEADING={{ .Heading }};MESSAGE={{ .Message }};BILLING={{ .BillingPath }}{{ end }}`)}, |
| 808 | "orgs/settings_billing.html": {Data: []byte(`{{ define "page" }}{{ with .Error }}ERROR={{ . }}{{ end }}{{ with .Notice }}NOTICE={{ . }}{{ end }}{{ with .BillingAlert }}{{ if .Message }}ALERT={{ .Message }}{{ end }}{{ end }}{{ with .Usage.Alert }}{{ if .Message }}USAGE_ALERT={{ .Message }};{{ end }}{{ end }}SEATS={{ .Seats.ActiveMembers }}/{{ .Seats.BillableSeats }}/{{ .Seats.PendingInvites }};{{ range .Usage.Rows }}USAGE={{ .Key }}:{{ .UsedLabel }}/{{ .LimitLabel }}/{{ .PercentLabel }}/{{ .StatusClass }};{{ end }}{{ range .Summary }}{{ if eq .Label "Payment source" }}PAYMENT={{ .Detail }};{{ end }}{{ end }}{{ if .IsSiteAdmin }}DEBUG={{ .Debug.StripeCustomerID }}|{{ .Debug.StripeSubscriptionID }}|{{ .Debug.StripeSubscriptionItemID }}|{{ .Debug.LastWebhookEventID }}|{{ .Debug.LastWebhookStatus }};{{ range .Debug.QuotaOverrides }}OVERRIDE={{ .Kind }}:{{ .Limit }};{{ end }}{{ end }}{{ range .Invoices }}INVOICE={{ .Number }};{{ end }}{{ end }}`)}, |
| 809 | "errors/403.html": {Data: []byte(`{{ define "page" }}403{{ end }}`)}, |
| 810 | "errors/404.html": {Data: []byte(`{{ define "page" }}404{{ end }}`)}, |
| 811 | "errors/500.html": {Data: []byte(`{{ define "page" }}500{{ end }}`)}, |
| 812 | } |
| 813 | rr, err := render.New(tmplFS, render.Options{}) |
| 814 | if err != nil { |
| 815 | t.Fatalf("render.New: %v", err) |
| 816 | } |
| 817 | h, err := orgsh.New(orgsh.Deps{ |
| 818 | Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), |
| 819 | Render: rr, |
| 820 | Pool: pool, |
| 821 | BaseURL: "https://shithub.example", |
| 822 | BillingEnabled: true, |
| 823 | BillingGracePeriod: 14 * 24 * time.Hour, |
| 824 | Stripe: remote, |
| 825 | StripeSuccessURL: "https://shithub.example/organizations/{org}/billing/success", |
| 826 | StripeCancelURL: "https://shithub.example/organizations/{org}/billing/cancel", |
| 827 | StripePortalReturnURL: "https://shithub.example/organizations/{org}/settings/billing", |
| 828 | StripeTeamPriceID: teamPriceID, |
| 829 | StripeProPriceID: proPriceID, |
| 830 | }) |
| 831 | if err != nil { |
| 832 | t.Fatalf("orgsh.New: %v", err) |
| 833 | } |
| 834 | mux := chi.NewRouter() |
| 835 | mux.Use(func(next http.Handler) http.Handler { |
| 836 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 837 | next.ServeHTTP(w, r.WithContext(middleware.WithCurrentUserForTest(r.Context(), viewer))) |
| 838 | }) |
| 839 | }) |
| 840 | h.MountCreate(mux) |
| 841 | h.MountBillingWebhook(mux) |
| 842 | return mux |
| 843 | } |
| 844 | |
| 845 | func insertBillingPendingInvitation(t *testing.T, db *pgxpool.Pool, orgID int64, email string, token []byte) { |
| 846 | t.Helper() |
| 847 | if _, err := db.Exec(context.Background(), ` |
| 848 | INSERT INTO org_invitations (org_id, target_email, role, token_hash, expires_at) |
| 849 | VALUES ($1, $2, 'member', $3, now() + interval '1 day') |
| 850 | `, orgID, email, token); err != nil { |
| 851 | t.Fatalf("insert pending billing invitation: %v", err) |
| 852 | } |
| 853 | } |
| 854 |