Go · 10052 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package orgs
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/jackc/pgx/v5"
13 "github.com/jackc/pgx/v5/pgtype"
14
15 "github.com/tenseleyFlow/shithub/internal/auth/email"
16 "github.com/tenseleyFlow/shithub/internal/auth/token"
17 orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc"
18 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
19 )
20
// InviteParams is the input to Invite. The invitee is identified either
// by username (which Invite resolves to a user ID through the users
// queries) or by email address, never both.
type InviteParams struct {
	// OrgID and InvitedByUserID are required; Invite rejects a zero
	// value for either.
	OrgID, InvitedByUserID int64
	// TargetUsername and TargetEmail select the invitee. Exactly one
	// of the two must be non-empty.
	TargetUsername, TargetEmail string
	// Role is the membership role to grant: "owner" | "member".
	Role string
}
31
// InviteResult bundles the freshly-created invitation row + the
// plaintext token. Callers typically render the token into a URL and
// drop it into an email; the token IS the credential — never persist
// it after the email send.
type InviteResult struct {
	// Invitation is the row returned by the CreateOrgInvitation query.
	Invitation orgsdb.OrgInvitation
	// Token is the encoded plaintext token; Invite stores only its hash
	// on the invitation row.
	Token string
}
40
41 // Invite creates a pending invitation. Resolves a username to a
42 // user_id when provided; otherwise treats the input as an email
43 // address (recipients without an account claim it on signup).
44 func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) {
45 if p.OrgID == 0 || p.InvitedByUserID == 0 {
46 return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required")
47 }
48 if (p.TargetUsername == "" && p.TargetEmail == "") ||
49 (p.TargetUsername != "" && p.TargetEmail != "") {
50 return InviteResult{}, ErrInvalidInvitationKind
51 }
52 role, err := parseRole(p.Role)
53 if err != nil {
54 return InviteResult{}, err
55 }
56
57 q := orgsdb.New()
58
59 var (
60 targetUserID pgtype.Int8
61 targetEmail pgtype.Text
62 )
63 if p.TargetUsername != "" {
64 uname := strings.ToLower(strings.TrimSpace(p.TargetUsername))
65 u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname)
66 if err != nil {
67 if errors.Is(err, pgx.ErrNoRows) {
68 return InviteResult{}, ErrUserNotFound
69 }
70 return InviteResult{}, err
71 }
72 // Already a member? short-circuit.
73 if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{
74 OrgID: p.OrgID, UserID: u.ID,
75 }); err == nil {
76 return InviteResult{}, ErrAlreadyMember
77 }
78 targetUserID = pgtype.Int8{Int64: u.ID, Valid: true}
79 } else {
80 email := strings.ToLower(strings.TrimSpace(p.TargetEmail))
81 if email == "" {
82 return InviteResult{}, ErrInvalidInvitationKind
83 }
84 targetEmail = pgtype.Text{String: email, Valid: true}
85 }
86
87 // Idempotency: if a pending invite for this target already exists,
88 // surface it rather than minting a fresh one (avoids token churn
89 // + accidental spam from re-clicked Invite buttons).
90 if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{
91 OrgID: p.OrgID,
92 TargetUserID: targetUserID,
93 TargetEmail: emailToCitext(targetEmail),
94 }); err == nil {
95 return InviteResult{}, ErrInvitationDuplicate
96 } else if !errors.Is(err, pgx.ErrNoRows) {
97 return InviteResult{}, err
98 }
99
100 tokEnc, tokHash, err := token.New()
101 if err != nil {
102 return InviteResult{}, fmt.Errorf("invite token: %w", err)
103 }
104 row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{
105 OrgID: p.OrgID,
106 InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true},
107 TargetUserID: targetUserID,
108 TargetEmail: emailToCitext(targetEmail),
109 Role: role,
110 TokenHash: tokHash,
111 ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true},
112 })
113 if err != nil {
114 return InviteResult{}, fmt.Errorf("create invitation: %w", err)
115 }
116
117 // Best-effort email — failures don't break the invitation row.
118 go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine
119
120 return InviteResult{Invitation: row, Token: tokEnc}, nil
121 }
122
123 // h is a tiny package-private helper struct so the closure in Invite
124 // can pin the email-side code without dragging extra fields into
125 // every call site. Keeps Invite signature focused.
126 var h struct {
127 tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string)
128 }
129
130 func init() {
131 h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) {
132 if deps.EmailSender == nil {
133 return
134 }
135 var to string
136 if row.TargetEmail.Valid {
137 to = string(row.TargetEmail.String)
138 } else if row.TargetUserID.Valid {
139 u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64)
140 if err != nil || !u.PrimaryEmailID.Valid {
141 return
142 }
143 em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64)
144 if err != nil || !em.Verified {
145 return
146 }
147 to = string(em.Email)
148 }
149 if to == "" {
150 return
151 }
152 url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc
153 text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" +
154 "Accept or decline: " + url + "\n\n" +
155 "This link expires in 7 days.\n"
156 _ = deps.EmailSender.Send(context.Background(), email.Message{
157 From: deps.EmailFrom,
158 To: to,
159 Subject: "[" + deps.SiteName + "] You've been invited to an organization",
160 Text: text,
161 HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>",
162 })
163 }
164 }
165
166 // AcceptInvitation marks an invitation accepted and adds the target
167 // as a member. The acceptor is the currently-logged-in user; for
168 // email-based invites the acceptor's verified email must match the
169 // invite's target_email.
170 func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error {
171 if err := validatePending(inv); err != nil {
172 return err
173 }
174 // Email-target invites: claim only when the acceptor owns the
175 // email (verified primary). Username-target invites: only the
176 // matching user can accept.
177 if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID {
178 return ErrUnauthorizedAcceptor
179 }
180 if inv.TargetEmail.Valid {
181 ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String))
182 if err != nil {
183 return err
184 }
185 if !ok {
186 return ErrUnauthorizedAcceptor
187 }
188 }
189
190 tx, err := deps.Pool.Begin(ctx)
191 if err != nil {
192 return err
193 }
194 committed := false
195 defer func() {
196 if !committed {
197 _ = tx.Rollback(ctx)
198 }
199 }()
200 q := orgsdb.New()
201 if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{
202 OrgID: inv.OrgID,
203 UserID: acceptorUserID,
204 Role: inv.Role,
205 InvitedByUserID: inv.InvitedByUserID,
206 }); err != nil {
207 return fmt.Errorf("add member: %w", err)
208 }
209 if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil {
210 return fmt.Errorf("mark accepted: %w", err)
211 }
212 if err := tx.Commit(ctx); err != nil {
213 return err
214 }
215 committed = true
216 return nil
217 }
218
219 // DeclineInvitation marks an invitation declined. Same target-binding
220 // rules as AcceptInvitation apply.
221 func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error {
222 if err := validatePending(inv); err != nil {
223 return err
224 }
225 if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID {
226 return ErrUnauthorizedAcceptor
227 }
228 if inv.TargetEmail.Valid {
229 ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String))
230 if err != nil {
231 return err
232 }
233 if !ok {
234 return ErrUnauthorizedAcceptor
235 }
236 }
237 return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID)
238 }
239
240 // CancelInvitation marks an invitation canceled. Caller is the org
241 // owner / inviter; policy is checked before this is invoked.
242 func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error {
243 if err := validatePending(inv); err != nil {
244 return err
245 }
246 return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID)
247 }
248
249 // LookupInvitationByToken fetches the invitation matching the
250 // supplied bearer token. Returns ErrInvitationNotFound when the token
251 // doesn't match a row.
252 func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) {
253 hash, err := token.HashOf(encodedToken)
254 if err != nil {
255 return orgsdb.OrgInvitation{}, ErrInvitationNotFound
256 }
257 row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash)
258 if err != nil {
259 if errors.Is(err, pgx.ErrNoRows) {
260 return orgsdb.OrgInvitation{}, ErrInvitationNotFound
261 }
262 return orgsdb.OrgInvitation{}, err
263 }
264 return row, nil
265 }
266
267 // ─── helpers ───────────────────────────────────────────────────────
268
// ErrUnauthorizedAcceptor is returned by AcceptInvitation and
// DeclineInvitation when the acting user doesn't own (as a verified
// address) the email the invite was issued to, or isn't the invited
// user.
var ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user")
273
274 func validatePending(inv orgsdb.OrgInvitation) error {
275 if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid {
276 return ErrInvitationConsumed
277 }
278 if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) {
279 return ErrInvitationExpired
280 }
281 return nil
282 }
283
284 func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) {
285 rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID)
286 if err != nil {
287 return false, err
288 }
289 low := strings.ToLower(email)
290 for _, e := range rows {
291 if strings.ToLower(string(e.Email)) == low && e.Verified {
292 return true, nil
293 }
294 }
295 return false, nil
296 }
297
298 // emailToCitext is a tiny shim because the sqlc-generated email
299 // column type for `target_email` is a non-pointer pgtype.Text-flavored
300 // citext. Pass-through when valid; zero value otherwise.
301 func emailToCitext(p pgtype.Text) pgtype.Text {
302 if !p.Valid {
303 return pgtype.Text{Valid: false}
304 }
305 return p
306 }
307