Sync org billing seats asynchronously
Authored by
mfwolffe <wolffemf@dukes.jmu.edu>
- SHA
f47e06238e65af358f347c0b4508dec79389c802- Parents
-
3c68bbc - Tree
c900332
f47e062
f47e06238e65af358f347c0b4508dec79389c8023c68bbc
c900332| Status | File | + | - |
|---|---|---|---|
| M |
cmd/shithubd/worker.go
|
17 | 0 |
| A |
internal/orgs/billing_jobs.go
|
23 | 0 |
| A |
internal/orgs/billing_jobs_test.go
|
92 | 0 |
| M |
internal/orgs/create.go
|
3 | 0 |
| M |
internal/orgs/invitations.go
|
3 | 0 |
| M |
internal/orgs/members.go
|
47 | 6 |
| A |
internal/worker/jobs/org_billing_seat_sync.go
|
102 | 0 |
| A |
internal/worker/jobs/org_billing_seat_sync_test.go
|
167 | 0 |
| M |
internal/worker/types.go
|
7 | 0 |
cmd/shithubd/worker.gomodified@@ -24,6 +24,7 @@ import ( | ||
| 24 | 24 | "github.com/tenseleyFlow/shithub/internal/auth/email" |
| 25 | 25 | "github.com/tenseleyFlow/shithub/internal/auth/secretbox" |
| 26 | 26 | "github.com/tenseleyFlow/shithub/internal/auth/throttle" |
| 27 | + "github.com/tenseleyFlow/shithub/internal/billing/stripebilling" | |
| 27 | 28 | "github.com/tenseleyFlow/shithub/internal/infra/config" |
| 28 | 29 | "github.com/tenseleyFlow/shithub/internal/infra/db" |
| 29 | 30 | "github.com/tenseleyFlow/shithub/internal/infra/storage" |
@@ -92,6 +93,19 @@ var workerCmd = &cobra.Command{ | ||
| 92 | 93 | "hint", "set Auth.TOTPKeyB64 to a base64 32-byte key", |
| 93 | 94 | "error", boxErr) |
| 94 | 95 | } |
| 96 | + var stripeRemote stripebilling.Remote | |
| 97 | + if cfg.Billing.Enabled { | |
| 98 | + remote, err := stripebilling.New(stripebilling.Config{ | |
| 99 | + SecretKey: cfg.Billing.Stripe.SecretKey, | |
| 100 | + WebhookSecret: cfg.Billing.Stripe.WebhookSecret, | |
| 101 | + TeamPriceID: cfg.Billing.Stripe.TeamPriceID, | |
| 102 | + AutomaticTax: cfg.Billing.Stripe.AutomaticTax, | |
| 103 | + }) | |
| 104 | + if err != nil { | |
| 105 | + return fmt.Errorf("billing: %w", err) | |
| 106 | + } | |
| 107 | + stripeRemote = remote | |
| 108 | + } | |
| 95 | 109 | |
| 96 | 110 | p := worker.NewPool(pool, worker.PoolConfig{ |
| 97 | 111 | Workers: count, |
@@ -132,6 +146,9 @@ var workerCmd = &cobra.Command{ | ||
| 132 | 146 | } |
| 133 | 147 | p.Register(worker.KindOrgGitHubImportDiscover, jobs.OrgGitHubImportDiscover(importDeps)) |
| 134 | 148 | p.Register(worker.KindOrgGitHubImportRepo, jobs.OrgGitHubImportRepo(importDeps)) |
| 149 | + p.Register(worker.KindOrgBillingSeatSync, jobs.OrgBillingSeatSync(jobs.OrgBillingSeatSyncDeps{ | |
| 150 | + Pool: pool, Logger: logger, Stripe: stripeRemote, | |
| 151 | + })) | |
| 135 | 152 | |
| 136 | 153 | notifSender, _ := pickNotifEmailSender(cfg) |
| 137 | 154 | p.Register(worker.KindNotifyFanout, jobs.NotifyFanout(jobs.NotifyFanoutDeps{ |
internal/orgs/billing_jobs.goadded@@ -0,0 +1,23 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package orgs | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "context" | |
| 7 | + | |
| 8 | + "github.com/jackc/pgx/v5" | |
| 9 | + | |
| 10 | + "github.com/tenseleyFlow/shithub/internal/worker" | |
| 11 | +) | |
| 12 | + | |
| 13 | +func enqueueBillingSeatSync(ctx context.Context, tx pgx.Tx, deps Deps, orgID int64) error { | |
| 14 | + if _, err := worker.Enqueue(ctx, tx, worker.KindOrgBillingSeatSync, map[string]any{ | |
| 15 | + "org_id": orgID, | |
| 16 | + }, worker.EnqueueOptions{}); err != nil { | |
| 17 | + return err | |
| 18 | + } | |
| 19 | + if err := worker.Notify(ctx, tx); err != nil && deps.Logger != nil { | |
| 20 | + deps.Logger.WarnContext(ctx, "org billing: notify seat sync", "error", err, "org_id", orgID) | |
| 21 | + } | |
| 22 | + return nil | |
| 23 | +} | |
internal/orgs/billing_jobs_test.goadded@@ -0,0 +1,92 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package orgs_test | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "context" | |
| 7 | + "strconv" | |
| 8 | + "testing" | |
| 9 | + | |
| 10 | + "github.com/jackc/pgx/v5/pgxpool" | |
| 11 | + | |
| 12 | + "github.com/tenseleyFlow/shithub/internal/orgs" | |
| 13 | + "github.com/tenseleyFlow/shithub/internal/worker" | |
| 14 | +) | |
| 15 | + | |
| 16 | +func TestCreateEnqueuesBillingSeatSync(t *testing.T) { | |
| 17 | + t.Parallel() | |
| 18 | + pool, deps, alice := setup(t) | |
| 19 | + | |
| 20 | + org, err := orgs.Create(context.Background(), deps, orgs.CreateParams{ | |
| 21 | + Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: alice, | |
| 22 | + }) | |
| 23 | + if err != nil { | |
| 24 | + t.Fatalf("create org: %v", err) | |
| 25 | + } | |
| 26 | + if got := countBillingSeatSyncJobs(t, pool, org.ID); got != 1 { | |
| 27 | + t.Fatalf("billing seat sync jobs=%d, want 1", got) | |
| 28 | + } | |
| 29 | +} | |
| 30 | + | |
| 31 | +func TestMemberChangesEnqueueBillingSeatSync(t *testing.T) { | |
| 32 | + t.Parallel() | |
| 33 | + pool, deps, alice := setup(t) | |
| 34 | + bob := mustUser(t, pool, "bob") | |
| 35 | + | |
| 36 | + org, err := orgs.Create(context.Background(), deps, orgs.CreateParams{ | |
| 37 | + Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: alice, | |
| 38 | + }) | |
| 39 | + if err != nil { | |
| 40 | + t.Fatalf("create org: %v", err) | |
| 41 | + } | |
| 42 | + if err := orgs.AddMember(context.Background(), deps, org.ID, bob, alice, "member"); err != nil { | |
| 43 | + t.Fatalf("AddMember: %v", err) | |
| 44 | + } | |
| 45 | + if got := countBillingSeatSyncJobs(t, pool, org.ID); got != 2 { | |
| 46 | + t.Fatalf("billing seat sync jobs after add=%d, want 2", got) | |
| 47 | + } | |
| 48 | + if err := orgs.RemoveMember(context.Background(), deps, org.ID, bob); err != nil { | |
| 49 | + t.Fatalf("RemoveMember: %v", err) | |
| 50 | + } | |
| 51 | + if got := countBillingSeatSyncJobs(t, pool, org.ID); got != 3 { | |
| 52 | + t.Fatalf("billing seat sync jobs after remove=%d, want 3", got) | |
| 53 | + } | |
| 54 | +} | |
| 55 | + | |
| 56 | +func TestAcceptInvitationEnqueuesBillingSeatSync(t *testing.T) { | |
| 57 | + t.Parallel() | |
| 58 | + pool, deps, alice := setup(t) | |
| 59 | + bob := mustUser(t, pool, "bob") | |
| 60 | + | |
| 61 | + org, err := orgs.Create(context.Background(), deps, orgs.CreateParams{ | |
| 62 | + Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: alice, | |
| 63 | + }) | |
| 64 | + if err != nil { | |
| 65 | + t.Fatalf("create org: %v", err) | |
| 66 | + } | |
| 67 | + res, err := orgs.Invite(context.Background(), deps, orgs.InviteParams{ | |
| 68 | + OrgID: org.ID, InvitedByUserID: alice, | |
| 69 | + TargetUsername: "bob", Role: "member", | |
| 70 | + }) | |
| 71 | + if err != nil { | |
| 72 | + t.Fatalf("Invite: %v", err) | |
| 73 | + } | |
| 74 | + if err := orgs.AcceptInvitation(context.Background(), deps, res.Invitation, bob); err != nil { | |
| 75 | + t.Fatalf("AcceptInvitation: %v", err) | |
| 76 | + } | |
| 77 | + if got := countBillingSeatSyncJobs(t, pool, org.ID); got != 2 { | |
| 78 | + t.Fatalf("billing seat sync jobs after accept=%d, want 2", got) | |
| 79 | + } | |
| 80 | +} | |
| 81 | + | |
| 82 | +func countBillingSeatSyncJobs(t *testing.T, pool *pgxpool.Pool, orgID int64) int { | |
| 83 | + t.Helper() | |
| 84 | + var jobs int | |
| 85 | + if err := pool.QueryRow(context.Background(), | |
| 86 | + `SELECT count(*) FROM jobs WHERE kind = $1 AND payload->>'org_id' = $2`, | |
| 87 | + worker.KindOrgBillingSeatSync, strconv.FormatInt(orgID, 10), | |
| 88 | + ).Scan(&jobs); err != nil { | |
| 89 | + t.Fatalf("query billing seat sync jobs: %v", err) | |
| 90 | + } | |
| 91 | + return jobs | |
| 92 | +} | |
internal/orgs/create.gomodified@@ -93,6 +93,9 @@ func Create(ctx context.Context, deps Deps, p CreateParams) (orgsdb.Org, error) | ||
| 93 | 93 | }); err != nil { |
| 94 | 94 | return orgsdb.Org{}, fmt.Errorf("seed owner: %w", err) |
| 95 | 95 | } |
| 96 | + if err := enqueueBillingSeatSync(ctx, tx, deps, row.ID); err != nil { | |
| 97 | + return orgsdb.Org{}, fmt.Errorf("enqueue billing seat sync: %w", err) | |
| 98 | + } | |
| 96 | 99 | |
| 97 | 100 | if err := tx.Commit(ctx); err != nil { |
| 98 | 101 | return orgsdb.Org{}, fmt.Errorf("commit: %w", err) |
internal/orgs/invitations.gomodified@@ -209,6 +209,9 @@ func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, | ||
| 209 | 209 | if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil { |
| 210 | 210 | return fmt.Errorf("mark accepted: %w", err) |
| 211 | 211 | } |
| 212 | + if err := enqueueBillingSeatSync(ctx, tx, deps, inv.OrgID); err != nil { | |
| 213 | + return fmt.Errorf("enqueue billing seat sync: %w", err) | |
| 214 | + } | |
| 212 | 215 | if err := tx.Commit(ctx); err != nil { |
| 213 | 216 | return err |
| 214 | 217 | } |
internal/orgs/members.gomodified@@ -26,12 +26,33 @@ func AddMember(ctx context.Context, deps Deps, orgID, userID, invitedByUserID in | ||
| 26 | 26 | if err != nil { |
| 27 | 27 | return err |
| 28 | 28 | } |
| 29 | - return orgsdb.New().AddOrgMember(ctx, deps.Pool, orgsdb.AddOrgMemberParams{ | |
| 29 | + tx, err := deps.Pool.Begin(ctx) | |
| 30 | + if err != nil { | |
| 31 | + return err | |
| 32 | + } | |
| 33 | + committed := false | |
| 34 | + defer func() { | |
| 35 | + if !committed { | |
| 36 | + _ = tx.Rollback(ctx) | |
| 37 | + } | |
| 38 | + }() | |
| 39 | + q := orgsdb.New() | |
| 40 | + if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{ | |
| 30 | 41 | OrgID: orgID, |
| 31 | 42 | UserID: userID, |
| 32 | 43 | Role: r, |
| 33 | 44 | InvitedByUserID: pgtype.Int8{Int64: invitedByUserID, Valid: invitedByUserID != 0}, |
| 34 | - }) | |
| 45 | + }); err != nil { | |
| 46 | + return err | |
| 47 | + } | |
| 48 | + if err := enqueueBillingSeatSync(ctx, tx, deps, orgID); err != nil { | |
| 49 | + return fmt.Errorf("enqueue billing seat sync: %w", err) | |
| 50 | + } | |
| 51 | + if err := tx.Commit(ctx); err != nil { | |
| 52 | + return err | |
| 53 | + } | |
| 54 | + committed = true | |
| 55 | + return nil | |
| 35 | 56 | } |
| 36 | 57 | |
| 37 | 58 | // ChangeRole updates a member's role with last-owner protection: the |
@@ -71,8 +92,18 @@ func ChangeRole(ctx context.Context, deps Deps, orgID, userID int64, role string | ||
| 71 | 92 | // applies — we refuse to drop the only owner. Removing oneself is |
| 72 | 93 | // fine when there are ≥2 owners. |
| 73 | 94 | func RemoveMember(ctx context.Context, deps Deps, orgID, userID int64) error { |
| 95 | + tx, err := deps.Pool.Begin(ctx) | |
| 96 | + if err != nil { | |
| 97 | + return err | |
| 98 | + } | |
| 99 | + committed := false | |
| 100 | + defer func() { | |
| 101 | + if !committed { | |
| 102 | + _ = tx.Rollback(ctx) | |
| 103 | + } | |
| 104 | + }() | |
| 74 | 105 | q := orgsdb.New() |
| 75 | - current, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{ | |
| 106 | + current, err := q.GetOrgMember(ctx, tx, orgsdb.GetOrgMemberParams{ | |
| 76 | 107 | OrgID: orgID, UserID: userID, |
| 77 | 108 | }) |
| 78 | 109 | if err != nil { |
@@ -82,7 +113,7 @@ func RemoveMember(ctx context.Context, deps Deps, orgID, userID int64) error { | ||
| 82 | 113 | return err |
| 83 | 114 | } |
| 84 | 115 | if current.Role == orgsdb.OrgRoleOwner { |
| 85 | - count, err := q.CountOrgOwners(ctx, deps.Pool, orgID) | |
| 116 | + count, err := q.CountOrgOwners(ctx, tx, orgID) | |
| 86 | 117 | if err != nil { |
| 87 | 118 | return err |
| 88 | 119 | } |
@@ -90,9 +121,19 @@ func RemoveMember(ctx context.Context, deps Deps, orgID, userID int64) error { | ||
| 90 | 121 | return ErrLastOwner |
| 91 | 122 | } |
| 92 | 123 | } |
| 93 | - return q.RemoveOrgMember(ctx, deps.Pool, orgsdb.RemoveOrgMemberParams{ | |
| 124 | + if err := q.RemoveOrgMember(ctx, tx, orgsdb.RemoveOrgMemberParams{ | |
| 94 | 125 | OrgID: orgID, UserID: userID, |
| 95 | - }) | |
| 126 | + }); err != nil { | |
| 127 | + return err | |
| 128 | + } | |
| 129 | + if err := enqueueBillingSeatSync(ctx, tx, deps, orgID); err != nil { | |
| 130 | + return fmt.Errorf("enqueue billing seat sync: %w", err) | |
| 131 | + } | |
| 132 | + if err := tx.Commit(ctx); err != nil { | |
| 133 | + return err | |
| 134 | + } | |
| 135 | + committed = true | |
| 136 | + return nil | |
| 96 | 137 | } |
| 97 | 138 | |
| 98 | 139 | // IsMember reports whether the user is a member of the org. Used by |
internal/worker/jobs/org_billing_seat_sync.goadded@@ -0,0 +1,102 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package jobs | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "context" | |
| 7 | + "encoding/json" | |
| 8 | + "errors" | |
| 9 | + "fmt" | |
| 10 | + "log/slog" | |
| 11 | + | |
| 12 | + "github.com/jackc/pgx/v5" | |
| 13 | + "github.com/jackc/pgx/v5/pgxpool" | |
| 14 | + | |
| 15 | + orgbilling "github.com/tenseleyFlow/shithub/internal/billing" | |
| 16 | + "github.com/tenseleyFlow/shithub/internal/billing/stripebilling" | |
| 17 | + "github.com/tenseleyFlow/shithub/internal/worker" | |
| 18 | +) | |
| 19 | + | |
| 20 | +type OrgBillingSeatSyncDeps struct { | |
| 21 | + Pool *pgxpool.Pool | |
| 22 | + Logger *slog.Logger | |
| 23 | + Stripe stripebilling.Remote | |
| 24 | +} | |
| 25 | + | |
| 26 | +type OrgBillingSeatSyncPayload struct { | |
| 27 | + OrgID int64 `json:"org_id"` | |
| 28 | +} | |
| 29 | + | |
| 30 | +func OrgBillingSeatSync(deps OrgBillingSeatSyncDeps) worker.Handler { | |
| 31 | + return func(ctx context.Context, raw json.RawMessage) error { | |
| 32 | + var p OrgBillingSeatSyncPayload | |
| 33 | + if err := json.Unmarshal(raw, &p); err != nil { | |
| 34 | + return worker.PoisonError(fmt.Errorf("bad payload: %w", err)) | |
| 35 | + } | |
| 36 | + if p.OrgID == 0 { | |
| 37 | + return worker.PoisonError(errors.New("missing org_id")) | |
| 38 | + } | |
| 39 | + | |
| 40 | + bdeps := orgbilling.Deps{Pool: deps.Pool} | |
| 41 | + state, err := orgbilling.GetOrgBillingState(ctx, bdeps, p.OrgID) | |
| 42 | + if err != nil { | |
| 43 | + if errors.Is(err, pgx.ErrNoRows) { | |
| 44 | + if deps.Logger != nil { | |
| 45 | + deps.Logger.InfoContext(ctx, "org billing seat sync skipped; billing state missing", | |
| 46 | + "org_id", p.OrgID) | |
| 47 | + } | |
| 48 | + return nil | |
| 49 | + } | |
| 50 | + return fmt.Errorf("load billing state: %w", err) | |
| 51 | + } | |
| 52 | + | |
| 53 | + members, err := orgbilling.CountBillableOrgMembers(ctx, bdeps, p.OrgID) | |
| 54 | + if err != nil { | |
| 55 | + return fmt.Errorf("count billable members: %w", err) | |
| 56 | + } | |
| 57 | + if _, err := orgbilling.SyncSeatSnapshot(ctx, bdeps, orgbilling.SeatSnapshot{ | |
| 58 | + OrgID: p.OrgID, | |
| 59 | + StripeSubscriptionID: state.StripeSubscriptionID.String, | |
| 60 | + ActiveMembers: members, | |
| 61 | + BillableSeats: members, | |
| 62 | + Source: "worker", | |
| 63 | + }); err != nil { | |
| 64 | + return fmt.Errorf("sync seat snapshot: %w", err) | |
| 65 | + } | |
| 66 | + | |
| 67 | + if deps.Stripe == nil || !shouldSyncStripeSeatQuantity(state) { | |
| 68 | + return nil | |
| 69 | + } | |
| 70 | + if err := deps.Stripe.UpdateSubscriptionItemQuantity(ctx, stripebilling.SeatQuantityInput{ | |
| 71 | + OrgID: p.OrgID, | |
| 72 | + SubscriptionItemID: state.StripeSubscriptionItemID.String, | |
| 73 | + Quantity: int64(members), | |
| 74 | + }); err != nil { | |
| 75 | + return fmt.Errorf("update stripe subscription item quantity: %w", err) | |
| 76 | + } | |
| 77 | + if deps.Logger != nil { | |
| 78 | + deps.Logger.InfoContext(ctx, "org billing seat sync updated subscription quantity", | |
| 79 | + "org_id", p.OrgID, | |
| 80 | + "seats", members, | |
| 81 | + "subscription_item_id", state.StripeSubscriptionItemID.String) | |
| 82 | + } | |
| 83 | + return nil | |
| 84 | + } | |
| 85 | +} | |
| 86 | + | |
| 87 | +func shouldSyncStripeSeatQuantity(state orgbilling.State) bool { | |
| 88 | + if !state.StripeSubscriptionItemID.Valid { | |
| 89 | + return false | |
| 90 | + } | |
| 91 | + switch state.SubscriptionStatus { | |
| 92 | + case orgbilling.SubscriptionStatusActive, | |
| 93 | + orgbilling.SubscriptionStatusTrialing, | |
| 94 | + orgbilling.SubscriptionStatusIncomplete, | |
| 95 | + orgbilling.SubscriptionStatusPastDue, | |
| 96 | + orgbilling.SubscriptionStatusUnpaid, | |
| 97 | + orgbilling.SubscriptionStatusPaused: | |
| 98 | + return true | |
| 99 | + default: | |
| 100 | + return false | |
| 101 | + } | |
| 102 | +} | |
internal/worker/jobs/org_billing_seat_sync_test.goadded@@ -0,0 +1,167 @@ | ||
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later | |
| 2 | + | |
| 3 | +package jobs_test | |
| 4 | + | |
| 5 | +import ( | |
| 6 | + "context" | |
| 7 | + "encoding/json" | |
| 8 | + "errors" | |
| 9 | + "io" | |
| 10 | + "log/slog" | |
| 11 | + "testing" | |
| 12 | + "time" | |
| 13 | + | |
| 14 | + "github.com/jackc/pgx/v5/pgxpool" | |
| 15 | + stripeapi "github.com/stripe/stripe-go/v85" | |
| 16 | + | |
| 17 | + "github.com/tenseleyFlow/shithub/internal/billing" | |
| 18 | + "github.com/tenseleyFlow/shithub/internal/billing/stripebilling" | |
| 19 | + "github.com/tenseleyFlow/shithub/internal/orgs" | |
| 20 | + "github.com/tenseleyFlow/shithub/internal/testing/dbtest" | |
| 21 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" | |
| 22 | + "github.com/tenseleyFlow/shithub/internal/worker/jobs" | |
| 23 | +) | |
| 24 | + | |
| 25 | +const billingFixtureHash = "$argon2id$v=19$m=16384,t=1,p=1$" + | |
| 26 | + "AAAAAAAAAAAAAAAA$" + | |
| 27 | + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" | |
| 28 | + | |
| 29 | +func TestOrgBillingSeatSyncUpdatesStateAndStripeQuantity(t *testing.T) { | |
| 30 | + t.Parallel() | |
| 31 | + ctx := context.Background() | |
| 32 | + pool, orgID := setupOrgBillingSeatSync(t) | |
| 33 | + memberID := createBillingUser(t, pool, "bob") | |
| 34 | + if err := orgs.AddMember(ctx, orgs.Deps{Pool: pool, Logger: discardLogger()}, orgID, memberID, 0, "member"); err != nil { | |
| 35 | + t.Fatalf("AddMember: %v", err) | |
| 36 | + } | |
| 37 | + if _, err := billing.SetStripeCustomer(ctx, billing.Deps{Pool: pool}, orgID, "cus_test"); err != nil { | |
| 38 | + t.Fatalf("SetStripeCustomer: %v", err) | |
| 39 | + } | |
| 40 | + start := time.Now().UTC().Truncate(time.Second) | |
| 41 | + if _, err := billing.ApplySubscriptionSnapshot(ctx, billing.Deps{Pool: pool}, billing.SubscriptionSnapshot{ | |
| 42 | + OrgID: orgID, | |
| 43 | + Plan: billing.PlanTeam, | |
| 44 | + Status: billing.SubscriptionStatusActive, | |
| 45 | + StripeSubscriptionID: "sub_test", | |
| 46 | + StripeSubscriptionItemID: "si_test", | |
| 47 | + CurrentPeriodStart: start, | |
| 48 | + CurrentPeriodEnd: start.Add(30 * 24 * time.Hour), | |
| 49 | + LastWebhookEventID: "evt_test", | |
| 50 | + }); err != nil { | |
| 51 | + t.Fatalf("ApplySubscriptionSnapshot: %v", err) | |
| 52 | + } | |
| 53 | + | |
| 54 | + var got stripebilling.SeatQuantityInput | |
| 55 | + handler := jobs.OrgBillingSeatSync(jobs.OrgBillingSeatSyncDeps{ | |
| 56 | + Pool: pool, | |
| 57 | + Logger: discardLogger(), | |
| 58 | + Stripe: &fakeSeatSyncStripeRemote{ | |
| 59 | + updateQuantityFn: func(_ context.Context, in stripebilling.SeatQuantityInput) error { | |
| 60 | + got = in | |
| 61 | + return nil | |
| 62 | + }, | |
| 63 | + }, | |
| 64 | + }) | |
| 65 | + payload, _ := json.Marshal(jobs.OrgBillingSeatSyncPayload{OrgID: orgID}) | |
| 66 | + if err := handler(ctx, payload); err != nil { | |
| 67 | + t.Fatalf("OrgBillingSeatSync: %v", err) | |
| 68 | + } | |
| 69 | + | |
| 70 | + state, err := billing.GetOrgBillingState(ctx, billing.Deps{Pool: pool}, orgID) | |
| 71 | + if err != nil { | |
| 72 | + t.Fatalf("GetOrgBillingState: %v", err) | |
| 73 | + } | |
| 74 | + if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid { | |
| 75 | + t.Fatalf("seat snapshot not reflected in state: %+v", state) | |
| 76 | + } | |
| 77 | + if got.OrgID != orgID || got.SubscriptionItemID != "si_test" || got.Quantity != 2 { | |
| 78 | + t.Fatalf("unexpected stripe quantity update: %+v", got) | |
| 79 | + } | |
| 80 | +} | |
| 81 | + | |
| 82 | +func TestOrgBillingSeatSyncSkipsStripeForFreeOrg(t *testing.T) { | |
| 83 | + t.Parallel() | |
| 84 | + ctx := context.Background() | |
| 85 | + pool, orgID := setupOrgBillingSeatSync(t) | |
| 86 | + called := false | |
| 87 | + handler := jobs.OrgBillingSeatSync(jobs.OrgBillingSeatSyncDeps{ | |
| 88 | + Pool: pool, | |
| 89 | + Logger: discardLogger(), | |
| 90 | + Stripe: &fakeSeatSyncStripeRemote{ | |
| 91 | + updateQuantityFn: func(_ context.Context, _ stripebilling.SeatQuantityInput) error { | |
| 92 | + called = true | |
| 93 | + return nil | |
| 94 | + }, | |
| 95 | + }, | |
| 96 | + }) | |
| 97 | + payload, _ := json.Marshal(jobs.OrgBillingSeatSyncPayload{OrgID: orgID}) | |
| 98 | + if err := handler(ctx, payload); err != nil { | |
| 99 | + t.Fatalf("OrgBillingSeatSync: %v", err) | |
| 100 | + } | |
| 101 | + if called { | |
| 102 | + t.Fatal("expected free org seat sync to skip Stripe quantity update") | |
| 103 | + } | |
| 104 | + state, err := billing.GetOrgBillingState(ctx, billing.Deps{Pool: pool}, orgID) | |
| 105 | + if err != nil { | |
| 106 | + t.Fatalf("GetOrgBillingState: %v", err) | |
| 107 | + } | |
| 108 | + if state.BillableSeats != 1 || !state.SeatSnapshotAt.Valid { | |
| 109 | + t.Fatalf("free org seat snapshot not recorded: %+v", state) | |
| 110 | + } | |
| 111 | +} | |
| 112 | + | |
| 113 | +func setupOrgBillingSeatSync(t *testing.T) (*pgxpool.Pool, int64) { | |
| 114 | + t.Helper() | |
| 115 | + pool := dbtest.NewTestDB(t) | |
| 116 | + ctx := context.Background() | |
| 117 | + ownerID := createBillingUser(t, pool, "owner") | |
| 118 | + org, err := orgs.Create(ctx, orgs.Deps{Pool: pool, Logger: discardLogger()}, orgs.CreateParams{ | |
| 119 | + Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: ownerID, | |
| 120 | + }) | |
| 121 | + if err != nil { | |
| 122 | + t.Fatalf("orgs.Create: %v", err) | |
| 123 | + } | |
| 124 | + return pool, org.ID | |
| 125 | +} | |
| 126 | + | |
| 127 | +func createBillingUser(t *testing.T, pool *pgxpool.Pool, username string) int64 { | |
| 128 | + t.Helper() | |
| 129 | + user, err := usersdb.New().CreateUser(context.Background(), pool, usersdb.CreateUserParams{ | |
| 130 | + Username: username, DisplayName: username, PasswordHash: billingFixtureHash, | |
| 131 | + }) | |
| 132 | + if err != nil { | |
| 133 | + t.Fatalf("CreateUser(%s): %v", username, err) | |
| 134 | + } | |
| 135 | + return user.ID | |
| 136 | +} | |
| 137 | + | |
| 138 | +func discardLogger() *slog.Logger { | |
| 139 | + return slog.New(slog.NewTextHandler(io.Discard, nil)) | |
| 140 | +} | |
| 141 | + | |
| 142 | +type fakeSeatSyncStripeRemote struct { | |
| 143 | + updateQuantityFn func(context.Context, stripebilling.SeatQuantityInput) error | |
| 144 | +} | |
| 145 | + | |
| 146 | +func (f *fakeSeatSyncStripeRemote) CreateCustomer(context.Context, stripebilling.CustomerInput) (stripebilling.Customer, error) { | |
| 147 | + return stripebilling.Customer{}, errors.New("unexpected CreateCustomer call") | |
| 148 | +} | |
| 149 | + | |
| 150 | +func (f *fakeSeatSyncStripeRemote) CreateCheckoutSession(context.Context, stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) { | |
| 151 | + return stripebilling.CheckoutSession{}, errors.New("unexpected CreateCheckoutSession call") | |
| 152 | +} | |
| 153 | + | |
| 154 | +func (f *fakeSeatSyncStripeRemote) CreatePortalSession(context.Context, stripebilling.PortalInput) (stripebilling.PortalSession, error) { | |
| 155 | + return stripebilling.PortalSession{}, errors.New("unexpected CreatePortalSession call") | |
| 156 | +} | |
| 157 | + | |
| 158 | +func (f *fakeSeatSyncStripeRemote) UpdateSubscriptionItemQuantity(ctx context.Context, in stripebilling.SeatQuantityInput) error { | |
| 159 | + if f.updateQuantityFn == nil { | |
| 160 | + return nil | |
| 161 | + } | |
| 162 | + return f.updateQuantityFn(ctx, in) | |
| 163 | +} | |
| 164 | + | |
| 165 | +func (f *fakeSeatSyncStripeRemote) VerifyWebhook([]byte, string) (stripeapi.Event, error) { | |
| 166 | + return stripeapi.Event{}, errors.New("unexpected VerifyWebhook call") | |
| 167 | +} | |
internal/worker/types.gomodified@@ -87,6 +87,13 @@ const ( | ||
| 87 | 87 | KindOrgGitHubImportRepo Kind = "org:github_import_repo" |
| 88 | 88 | ) |
| 89 | 89 | |
| 90 | +// Organization billing kinds. seat_sync recomputes active org members, | |
| 91 | +// records a local billing snapshot, and updates Stripe subscription-item | |
| 92 | +// quantity when hosted Team billing is active. | |
| 93 | +const ( | |
| 94 | + KindOrgBillingSeatSync Kind = "org:billing_seat_sync" | |
| 95 | +) | |
| 96 | + | |
| 90 | 97 | // NotifyChannel is the Postgres LISTEN/NOTIFY channel the pool subscribes |
| 91 | 98 | // to so it wakes up immediately when a job is enqueued, instead of |
| 92 | 99 | // polling. Callers wrapping enqueue in a tx must NOTIFY inside the |