Go · 5762 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package jobs_test
4
5 import (
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "testing"
12 "time"
13
14 "github.com/jackc/pgx/v5/pgxpool"
15 stripeapi "github.com/stripe/stripe-go/v85"
16
17 "github.com/tenseleyFlow/shithub/internal/billing"
18 "github.com/tenseleyFlow/shithub/internal/billing/stripebilling"
19 "github.com/tenseleyFlow/shithub/internal/orgs"
20 "github.com/tenseleyFlow/shithub/internal/testing/dbtest"
21 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
22 "github.com/tenseleyFlow/shithub/internal/worker/jobs"
23 )
24
// billingFixtureHash is the fixed argon2id-formatted string that
// createBillingUser stores as PasswordHash for every fixture user. The
// segments after the parameter prefix are all-'A' filler (presumably
// placeholder salt and digest values; no test in this file logs in with them).
const (
	billingFixtureHash = "$argon2id$v=19$m=16384,t=1,p=1$" + // algorithm, version and cost parameters
		"AAAAAAAAAAAAAAAA$" + // first filler segment
		"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" // second filler segment
)
28
29 func TestOrgBillingSeatSyncUpdatesStateAndStripeQuantity(t *testing.T) {
30 t.Parallel()
31 ctx := context.Background()
32 pool, orgID := setupOrgBillingSeatSync(t)
33 memberID := createBillingUser(t, pool, "bob")
34 if err := orgs.AddMember(ctx, orgs.Deps{Pool: pool, Logger: discardLogger()}, orgID, memberID, 0, "member"); err != nil {
35 t.Fatalf("AddMember: %v", err)
36 }
37 if _, err := billing.SetStripeCustomer(ctx, billing.Deps{Pool: pool}, orgID, "cus_test"); err != nil {
38 t.Fatalf("SetStripeCustomer: %v", err)
39 }
40 start := time.Now().UTC().Truncate(time.Second)
41 if _, err := billing.ApplySubscriptionSnapshot(ctx, billing.Deps{Pool: pool}, billing.SubscriptionSnapshot{
42 OrgID: orgID,
43 Plan: billing.PlanTeam,
44 Status: billing.SubscriptionStatusActive,
45 StripeSubscriptionID: "sub_test",
46 StripeSubscriptionItemID: "si_test",
47 CurrentPeriodStart: start,
48 CurrentPeriodEnd: start.Add(30 * 24 * time.Hour),
49 LastWebhookEventID: "evt_test",
50 }); err != nil {
51 t.Fatalf("ApplySubscriptionSnapshot: %v", err)
52 }
53
54 var got stripebilling.SeatQuantityInput
55 handler := jobs.OrgBillingSeatSync(jobs.OrgBillingSeatSyncDeps{
56 Pool: pool,
57 Logger: discardLogger(),
58 Stripe: &fakeSeatSyncStripeRemote{
59 updateQuantityFn: func(_ context.Context, in stripebilling.SeatQuantityInput) error {
60 got = in
61 return nil
62 },
63 },
64 })
65 payload, _ := json.Marshal(jobs.OrgBillingSeatSyncPayload{OrgID: orgID})
66 if err := handler(ctx, payload); err != nil {
67 t.Fatalf("OrgBillingSeatSync: %v", err)
68 }
69
70 state, err := billing.GetOrgBillingState(ctx, billing.Deps{Pool: pool}, orgID)
71 if err != nil {
72 t.Fatalf("GetOrgBillingState: %v", err)
73 }
74 if state.BillableSeats != 2 || !state.SeatSnapshotAt.Valid {
75 t.Fatalf("seat snapshot not reflected in state: %+v", state)
76 }
77 if got.OrgID != orgID || got.SubscriptionItemID != "si_test" || got.Quantity != 2 {
78 t.Fatalf("unexpected stripe quantity update: %+v", got)
79 }
80 }
81
82 func TestOrgBillingSeatSyncSkipsStripeForFreeOrg(t *testing.T) {
83 t.Parallel()
84 ctx := context.Background()
85 pool, orgID := setupOrgBillingSeatSync(t)
86 called := false
87 handler := jobs.OrgBillingSeatSync(jobs.OrgBillingSeatSyncDeps{
88 Pool: pool,
89 Logger: discardLogger(),
90 Stripe: &fakeSeatSyncStripeRemote{
91 updateQuantityFn: func(_ context.Context, _ stripebilling.SeatQuantityInput) error {
92 called = true
93 return nil
94 },
95 },
96 })
97 payload, _ := json.Marshal(jobs.OrgBillingSeatSyncPayload{OrgID: orgID})
98 if err := handler(ctx, payload); err != nil {
99 t.Fatalf("OrgBillingSeatSync: %v", err)
100 }
101 if called {
102 t.Fatal("expected free org seat sync to skip Stripe quantity update")
103 }
104 state, err := billing.GetOrgBillingState(ctx, billing.Deps{Pool: pool}, orgID)
105 if err != nil {
106 t.Fatalf("GetOrgBillingState: %v", err)
107 }
108 if state.BillableSeats != 1 || !state.SeatSnapshotAt.Valid {
109 t.Fatalf("free org seat snapshot not recorded: %+v", state)
110 }
111 }
112
113 func setupOrgBillingSeatSync(t *testing.T) (*pgxpool.Pool, int64) {
114 t.Helper()
115 pool := dbtest.NewTestDB(t)
116 ctx := context.Background()
117 ownerID := createBillingUser(t, pool, "owner")
118 org, err := orgs.Create(ctx, orgs.Deps{Pool: pool, Logger: discardLogger()}, orgs.CreateParams{
119 Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: ownerID,
120 })
121 if err != nil {
122 t.Fatalf("orgs.Create: %v", err)
123 }
124 return pool, org.ID
125 }
126
127 func createBillingUser(t *testing.T, pool *pgxpool.Pool, username string) int64 {
128 t.Helper()
129 user, err := usersdb.New().CreateUser(context.Background(), pool, usersdb.CreateUserParams{
130 Username: username, DisplayName: username, PasswordHash: billingFixtureHash,
131 })
132 if err != nil {
133 t.Fatalf("CreateUser(%s): %v", username, err)
134 }
135 return user.ID
136 }
137
// discardLogger returns a text-format slog.Logger whose output is written to
// io.Discard, so code under test can log without cluttering test output.
func discardLogger() *slog.Logger {
	sink := slog.NewTextHandler(io.Discard, nil)
	return slog.New(sink)
}
141
142 type fakeSeatSyncStripeRemote struct {
143 updateQuantityFn func(context.Context, stripebilling.SeatQuantityInput) error
144 }
145
146 func (f *fakeSeatSyncStripeRemote) CreateCustomer(context.Context, stripebilling.CustomerInput) (stripebilling.Customer, error) {
147 return stripebilling.Customer{}, errors.New("unexpected CreateCustomer call")
148 }
149
150 func (f *fakeSeatSyncStripeRemote) CreateCheckoutSession(context.Context, stripebilling.CheckoutInput) (stripebilling.CheckoutSession, error) {
151 return stripebilling.CheckoutSession{}, errors.New("unexpected CreateCheckoutSession call")
152 }
153
154 func (f *fakeSeatSyncStripeRemote) CreatePortalSession(context.Context, stripebilling.PortalInput) (stripebilling.PortalSession, error) {
155 return stripebilling.PortalSession{}, errors.New("unexpected CreatePortalSession call")
156 }
157
158 func (f *fakeSeatSyncStripeRemote) UpdateSubscriptionItemQuantity(ctx context.Context, in stripebilling.SeatQuantityInput) error {
159 if f.updateQuantityFn == nil {
160 return nil
161 }
162 return f.updateQuantityFn(ctx, in)
163 }
164
165 func (f *fakeSeatSyncStripeRemote) VerifyWebhook([]byte, string) (stripeapi.Event, error) {
166 return stripeapi.Event{}, errors.New("unexpected VerifyWebhook call")
167 }
168