Go · 11000 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package orgs
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/jackc/pgx/v5"
13 "github.com/jackc/pgx/v5/pgtype"
14
15 "github.com/tenseleyFlow/shithub/internal/auth/email"
16 "github.com/tenseleyFlow/shithub/internal/auth/token"
17 "github.com/tenseleyFlow/shithub/internal/entitlements"
18 orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc"
19 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
20 )
21
// InviteParams is the input to Invite.
//
// The invitee is identified either by account name or by email address,
// never both: set exactly one of TargetUsername and TargetEmail. A
// username is resolved to a user_id through the users table; an email
// address is kept as the target so that someone without an account yet
// can claim the invitation after signing up.
type InviteParams struct {
	OrgID           int64  // organization being joined; required (non-zero)
	InvitedByUserID int64  // user issuing the invitation; required (non-zero)
	TargetUsername  string // invitee by username; mutually exclusive with TargetEmail
	TargetEmail     string // invitee by email; mutually exclusive with TargetUsername
	Role            string // "owner" | "member"
}
32
33 // InviteResult bundles the freshly-created invitation row + the
34 // plaintext token. Callers typically render the token into a URL and
35 // drop it into an email; the token IS the credential — never persist
36 // it after the email send.
37 type InviteResult struct {
38 Invitation orgsdb.OrgInvitation
39 Token string
40 }
41
42 // Invite creates a pending invitation. Resolves a username to a
43 // user_id when provided; otherwise treats the input as an email
44 // address (recipients without an account claim it on signup).
45 func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) {
46 if p.OrgID == 0 || p.InvitedByUserID == 0 {
47 return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required")
48 }
49 if (p.TargetUsername == "" && p.TargetEmail == "") ||
50 (p.TargetUsername != "" && p.TargetEmail != "") {
51 return InviteResult{}, ErrInvalidInvitationKind
52 }
53 role, err := parseRole(p.Role)
54 if err != nil {
55 return InviteResult{}, err
56 }
57
58 q := orgsdb.New()
59
60 var (
61 targetUserID pgtype.Int8
62 targetEmail pgtype.Text
63 )
64 if p.TargetUsername != "" {
65 uname := strings.ToLower(strings.TrimSpace(p.TargetUsername))
66 u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname)
67 if err != nil {
68 if errors.Is(err, pgx.ErrNoRows) {
69 return InviteResult{}, ErrUserNotFound
70 }
71 return InviteResult{}, err
72 }
73 // Already a member? short-circuit.
74 if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{
75 OrgID: p.OrgID, UserID: u.ID,
76 }); err == nil {
77 return InviteResult{}, ErrAlreadyMember
78 }
79 targetUserID = pgtype.Int8{Int64: u.ID, Valid: true}
80 } else {
81 email := strings.ToLower(strings.TrimSpace(p.TargetEmail))
82 if email == "" {
83 return InviteResult{}, ErrInvalidInvitationKind
84 }
85 targetEmail = pgtype.Text{String: email, Valid: true}
86 }
87
88 // Idempotency: if a pending invite for this target already exists,
89 // surface it rather than minting a fresh one (avoids token churn
90 // + accidental spam from re-clicked Invite buttons).
91 if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{
92 OrgID: p.OrgID,
93 TargetUserID: targetUserID,
94 TargetEmail: emailToCitext(targetEmail),
95 }); err == nil {
96 return InviteResult{}, ErrInvitationDuplicate
97 } else if !errors.Is(err, pgx.ErrNoRows) {
98 return InviteResult{}, err
99 }
100 if role == orgsdb.OrgRoleOwner {
101 var check entitlements.PrivateCollaborationCheck
102 if targetUserID.Valid {
103 check, err = entitlements.CheckOrgOwnerPrivateCollaboration(ctx, entitlements.Deps{Pool: deps.Pool}, p.OrgID, targetUserID.Int64)
104 } else {
105 check, err = entitlements.CheckPrivateInvitationSlot(ctx, entitlements.Deps{Pool: deps.Pool}, p.OrgID)
106 }
107 if err != nil {
108 return InviteResult{}, err
109 }
110 if err := check.Err(); err != nil {
111 return InviteResult{}, err
112 }
113 }
114
115 tokEnc, tokHash, err := token.New()
116 if err != nil {
117 return InviteResult{}, fmt.Errorf("invite token: %w", err)
118 }
119 row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{
120 OrgID: p.OrgID,
121 InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true},
122 TargetUserID: targetUserID,
123 TargetEmail: emailToCitext(targetEmail),
124 Role: role,
125 TokenHash: tokHash,
126 ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true},
127 })
128 if err != nil {
129 return InviteResult{}, fmt.Errorf("create invitation: %w", err)
130 }
131
132 // Best-effort email — failures don't break the invitation row.
133 go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine
134
135 return InviteResult{Invitation: row, Token: tokEnc}, nil
136 }
137
138 // h is a tiny package-private helper struct so the closure in Invite
139 // can pin the email-side code without dragging extra fields into
140 // every call site. Keeps Invite signature focused.
141 var h struct {
142 tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string)
143 }
144
145 func init() {
146 h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) {
147 if deps.EmailSender == nil {
148 return
149 }
150 var to string
151 if row.TargetEmail.Valid {
152 to = string(row.TargetEmail.String)
153 } else if row.TargetUserID.Valid {
154 u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64)
155 if err != nil || !u.PrimaryEmailID.Valid {
156 return
157 }
158 em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64)
159 if err != nil || !em.Verified {
160 return
161 }
162 to = string(em.Email)
163 }
164 if to == "" {
165 return
166 }
167 url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc
168 text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" +
169 "Accept or decline: " + url + "\n\n" +
170 "This link expires in 7 days.\n"
171 _ = deps.EmailSender.Send(context.Background(), email.Message{
172 From: deps.EmailFrom,
173 To: to,
174 Subject: "[" + deps.SiteName + "] You've been invited to an organization",
175 Text: text,
176 HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>",
177 })
178 }
179 }
180
181 // AcceptInvitation marks an invitation accepted and adds the target
182 // as a member. The acceptor is the currently-logged-in user; for
183 // email-based invites the acceptor's verified email must match the
184 // invite's target_email.
185 func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error {
186 if err := validatePending(inv); err != nil {
187 return err
188 }
189 // Email-target invites: claim only when the acceptor owns the
190 // email (verified primary). Username-target invites: only the
191 // matching user can accept.
192 if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID {
193 return ErrUnauthorizedAcceptor
194 }
195 if inv.TargetEmail.Valid {
196 ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String))
197 if err != nil {
198 return err
199 }
200 if !ok {
201 return ErrUnauthorizedAcceptor
202 }
203 }
204 if inv.Role == orgsdb.OrgRoleOwner {
205 check, err := entitlements.CheckOrgOwnerPrivateCollaboration(ctx, entitlements.Deps{Pool: deps.Pool}, inv.OrgID, acceptorUserID)
206 if err != nil {
207 return err
208 }
209 if err := check.Err(); err != nil {
210 return err
211 }
212 }
213
214 tx, err := deps.Pool.Begin(ctx)
215 if err != nil {
216 return err
217 }
218 committed := false
219 defer func() {
220 if !committed {
221 _ = tx.Rollback(ctx)
222 }
223 }()
224 q := orgsdb.New()
225 if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{
226 OrgID: inv.OrgID,
227 UserID: acceptorUserID,
228 Role: inv.Role,
229 InvitedByUserID: inv.InvitedByUserID,
230 }); err != nil {
231 return fmt.Errorf("add member: %w", err)
232 }
233 if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil {
234 return fmt.Errorf("mark accepted: %w", err)
235 }
236 if err := enqueueBillingSeatSync(ctx, tx, deps, inv.OrgID); err != nil {
237 return fmt.Errorf("enqueue billing seat sync: %w", err)
238 }
239 if err := tx.Commit(ctx); err != nil {
240 return err
241 }
242 committed = true
243 return nil
244 }
245
246 // DeclineInvitation marks an invitation declined. Same target-binding
247 // rules as AcceptInvitation apply.
248 func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error {
249 if err := validatePending(inv); err != nil {
250 return err
251 }
252 if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID {
253 return ErrUnauthorizedAcceptor
254 }
255 if inv.TargetEmail.Valid {
256 ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String))
257 if err != nil {
258 return err
259 }
260 if !ok {
261 return ErrUnauthorizedAcceptor
262 }
263 }
264 return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID)
265 }
266
267 // CancelInvitation marks an invitation canceled. Caller is the org
268 // owner / inviter; policy is checked before this is invoked.
269 func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error {
270 if err := validatePending(inv); err != nil {
271 return err
272 }
273 return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID)
274 }
275
276 // LookupInvitationByToken fetches the invitation matching the
277 // supplied bearer token. Returns ErrInvitationNotFound when the token
278 // doesn't match a row.
279 func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) {
280 hash, err := token.HashOf(encodedToken)
281 if err != nil {
282 return orgsdb.OrgInvitation{}, ErrInvitationNotFound
283 }
284 row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash)
285 if err != nil {
286 if errors.Is(err, pgx.ErrNoRows) {
287 return orgsdb.OrgInvitation{}, ErrInvitationNotFound
288 }
289 return orgsdb.OrgInvitation{}, err
290 }
291 return row, nil
292 }
293
294 // ─── helpers ───────────────────────────────────────────────────────
295
// ErrUnauthorizedAcceptor signals that the user acting on an invitation is
// not its target. Either they are not the user invited by username, or
// they do not own a verified address matching the invited email.
var (
	ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user")
)
300
301 func validatePending(inv orgsdb.OrgInvitation) error {
302 if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid {
303 return ErrInvitationConsumed
304 }
305 if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) {
306 return ErrInvitationExpired
307 }
308 return nil
309 }
310
311 func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) {
312 rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID)
313 if err != nil {
314 return false, err
315 }
316 low := strings.ToLower(email)
317 for _, e := range rows {
318 if strings.ToLower(string(e.Email)) == low && e.Verified {
319 return true, nil
320 }
321 }
322 return false, nil
323 }
324
325 // emailToCitext is a tiny shim because the sqlc-generated email
326 // column type for `target_email` is a non-pointer pgtype.Text-flavored
327 // citext. Pass-through when valid; zero value otherwise.
328 func emailToCitext(p pgtype.Text) pgtype.Text {
329 if !p.Valid {
330 return pgtype.Text{Valid: false}
331 }
332 return p
333 }
334