| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "errors" |
| 8 | "fmt" |
| 9 | "strings" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/jackc/pgx/v5" |
| 13 | "github.com/jackc/pgx/v5/pgtype" |
| 14 | |
| 15 | "github.com/tenseleyFlow/shithub/internal/auth/email" |
| 16 | "github.com/tenseleyFlow/shithub/internal/auth/token" |
| 17 | orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 18 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 19 | ) |
| 20 | |
// InviteParams is the input to Invite.
//
// Exactly one of TargetUsername and TargetEmail should be set. A
// username is looked up in the users table and yields a user-targeted
// invitation; an email address yields an email-targeted invitation that
// the recipient can claim once they hold that address as verified.
type InviteParams struct {
	// OrgID identifies the organization the invitee would join.
	OrgID int64
	// InvitedByUserID records which user issued the invitation.
	InvitedByUserID int64
	// TargetUsername names an existing account to invite.
	TargetUsername string
	// TargetEmail is an address to invite; no account is required yet.
	TargetEmail string
	// Role is the membership role granted on acceptance: "owner" or "member".
	Role string
}
| 31 | |
| 32 | // InviteResult bundles the freshly-created invitation row + the |
| 33 | // plaintext token. Callers typically render the token into a URL and |
| 34 | // drop it into an email; the token IS the credential — never persist |
| 35 | // it after the email send. |
| 36 | type InviteResult struct { |
| 37 | Invitation orgsdb.OrgInvitation |
| 38 | Token string |
| 39 | } |
| 40 | |
| 41 | // Invite creates a pending invitation. Resolves a username to a |
| 42 | // user_id when provided; otherwise treats the input as an email |
| 43 | // address (recipients without an account claim it on signup). |
| 44 | func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) { |
| 45 | if p.OrgID == 0 || p.InvitedByUserID == 0 { |
| 46 | return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required") |
| 47 | } |
| 48 | if (p.TargetUsername == "" && p.TargetEmail == "") || |
| 49 | (p.TargetUsername != "" && p.TargetEmail != "") { |
| 50 | return InviteResult{}, ErrInvalidInvitationKind |
| 51 | } |
| 52 | role, err := parseRole(p.Role) |
| 53 | if err != nil { |
| 54 | return InviteResult{}, err |
| 55 | } |
| 56 | |
| 57 | q := orgsdb.New() |
| 58 | |
| 59 | var ( |
| 60 | targetUserID pgtype.Int8 |
| 61 | targetEmail pgtype.Text |
| 62 | ) |
| 63 | if p.TargetUsername != "" { |
| 64 | uname := strings.ToLower(strings.TrimSpace(p.TargetUsername)) |
| 65 | u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname) |
| 66 | if err != nil { |
| 67 | if errors.Is(err, pgx.ErrNoRows) { |
| 68 | return InviteResult{}, ErrUserNotFound |
| 69 | } |
| 70 | return InviteResult{}, err |
| 71 | } |
| 72 | // Already a member? short-circuit. |
| 73 | if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{ |
| 74 | OrgID: p.OrgID, UserID: u.ID, |
| 75 | }); err == nil { |
| 76 | return InviteResult{}, ErrAlreadyMember |
| 77 | } |
| 78 | targetUserID = pgtype.Int8{Int64: u.ID, Valid: true} |
| 79 | } else { |
| 80 | email := strings.ToLower(strings.TrimSpace(p.TargetEmail)) |
| 81 | if email == "" { |
| 82 | return InviteResult{}, ErrInvalidInvitationKind |
| 83 | } |
| 84 | targetEmail = pgtype.Text{String: email, Valid: true} |
| 85 | } |
| 86 | |
| 87 | // Idempotency: if a pending invite for this target already exists, |
| 88 | // surface it rather than minting a fresh one (avoids token churn |
| 89 | // + accidental spam from re-clicked Invite buttons). |
| 90 | if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{ |
| 91 | OrgID: p.OrgID, |
| 92 | TargetUserID: targetUserID, |
| 93 | TargetEmail: emailToCitext(targetEmail), |
| 94 | }); err == nil { |
| 95 | return InviteResult{}, ErrInvitationDuplicate |
| 96 | } else if !errors.Is(err, pgx.ErrNoRows) { |
| 97 | return InviteResult{}, err |
| 98 | } |
| 99 | |
| 100 | tokEnc, tokHash, err := token.New() |
| 101 | if err != nil { |
| 102 | return InviteResult{}, fmt.Errorf("invite token: %w", err) |
| 103 | } |
| 104 | row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{ |
| 105 | OrgID: p.OrgID, |
| 106 | InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true}, |
| 107 | TargetUserID: targetUserID, |
| 108 | TargetEmail: emailToCitext(targetEmail), |
| 109 | Role: role, |
| 110 | TokenHash: tokHash, |
| 111 | ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true}, |
| 112 | }) |
| 113 | if err != nil { |
| 114 | return InviteResult{}, fmt.Errorf("create invitation: %w", err) |
| 115 | } |
| 116 | |
| 117 | // Best-effort email — failures don't break the invitation row. |
| 118 | go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine |
| 119 | |
| 120 | return InviteResult{Invitation: row, Token: tokEnc}, nil |
| 121 | } |
| 122 | |
| 123 | // h is a tiny package-private helper struct so the closure in Invite |
| 124 | // can pin the email-side code without dragging extra fields into |
| 125 | // every call site. Keeps Invite signature focused. |
| 126 | var h struct { |
| 127 | tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) |
| 128 | } |
| 129 | |
| 130 | func init() { |
| 131 | h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) { |
| 132 | if deps.EmailSender == nil { |
| 133 | return |
| 134 | } |
| 135 | var to string |
| 136 | if row.TargetEmail.Valid { |
| 137 | to = string(row.TargetEmail.String) |
| 138 | } else if row.TargetUserID.Valid { |
| 139 | u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64) |
| 140 | if err != nil || !u.PrimaryEmailID.Valid { |
| 141 | return |
| 142 | } |
| 143 | em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64) |
| 144 | if err != nil || !em.Verified { |
| 145 | return |
| 146 | } |
| 147 | to = string(em.Email) |
| 148 | } |
| 149 | if to == "" { |
| 150 | return |
| 151 | } |
| 152 | url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc |
| 153 | text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" + |
| 154 | "Accept or decline: " + url + "\n\n" + |
| 155 | "This link expires in 7 days.\n" |
| 156 | _ = deps.EmailSender.Send(context.Background(), email.Message{ |
| 157 | From: deps.EmailFrom, |
| 158 | To: to, |
| 159 | Subject: "[" + deps.SiteName + "] You've been invited to an organization", |
| 160 | Text: text, |
| 161 | HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>", |
| 162 | }) |
| 163 | } |
| 164 | } |
| 165 | |
| 166 | // AcceptInvitation marks an invitation accepted and adds the target |
| 167 | // as a member. The acceptor is the currently-logged-in user; for |
| 168 | // email-based invites the acceptor's verified email must match the |
| 169 | // invite's target_email. |
| 170 | func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error { |
| 171 | if err := validatePending(inv); err != nil { |
| 172 | return err |
| 173 | } |
| 174 | // Email-target invites: claim only when the acceptor owns the |
| 175 | // email (verified primary). Username-target invites: only the |
| 176 | // matching user can accept. |
| 177 | if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID { |
| 178 | return ErrUnauthorizedAcceptor |
| 179 | } |
| 180 | if inv.TargetEmail.Valid { |
| 181 | ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String)) |
| 182 | if err != nil { |
| 183 | return err |
| 184 | } |
| 185 | if !ok { |
| 186 | return ErrUnauthorizedAcceptor |
| 187 | } |
| 188 | } |
| 189 | |
| 190 | tx, err := deps.Pool.Begin(ctx) |
| 191 | if err != nil { |
| 192 | return err |
| 193 | } |
| 194 | committed := false |
| 195 | defer func() { |
| 196 | if !committed { |
| 197 | _ = tx.Rollback(ctx) |
| 198 | } |
| 199 | }() |
| 200 | q := orgsdb.New() |
| 201 | if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{ |
| 202 | OrgID: inv.OrgID, |
| 203 | UserID: acceptorUserID, |
| 204 | Role: inv.Role, |
| 205 | InvitedByUserID: inv.InvitedByUserID, |
| 206 | }); err != nil { |
| 207 | return fmt.Errorf("add member: %w", err) |
| 208 | } |
| 209 | if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil { |
| 210 | return fmt.Errorf("mark accepted: %w", err) |
| 211 | } |
| 212 | if err := enqueueBillingSeatSync(ctx, tx, deps, inv.OrgID); err != nil { |
| 213 | return fmt.Errorf("enqueue billing seat sync: %w", err) |
| 214 | } |
| 215 | if err := tx.Commit(ctx); err != nil { |
| 216 | return err |
| 217 | } |
| 218 | committed = true |
| 219 | return nil |
| 220 | } |
| 221 | |
| 222 | // DeclineInvitation marks an invitation declined. Same target-binding |
| 223 | // rules as AcceptInvitation apply. |
| 224 | func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error { |
| 225 | if err := validatePending(inv); err != nil { |
| 226 | return err |
| 227 | } |
| 228 | if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID { |
| 229 | return ErrUnauthorizedAcceptor |
| 230 | } |
| 231 | if inv.TargetEmail.Valid { |
| 232 | ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String)) |
| 233 | if err != nil { |
| 234 | return err |
| 235 | } |
| 236 | if !ok { |
| 237 | return ErrUnauthorizedAcceptor |
| 238 | } |
| 239 | } |
| 240 | return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID) |
| 241 | } |
| 242 | |
| 243 | // CancelInvitation marks an invitation canceled. Caller is the org |
| 244 | // owner / inviter; policy is checked before this is invoked. |
| 245 | func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error { |
| 246 | if err := validatePending(inv); err != nil { |
| 247 | return err |
| 248 | } |
| 249 | return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID) |
| 250 | } |
| 251 | |
| 252 | // LookupInvitationByToken fetches the invitation matching the |
| 253 | // supplied bearer token. Returns ErrInvitationNotFound when the token |
| 254 | // doesn't match a row. |
| 255 | func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) { |
| 256 | hash, err := token.HashOf(encodedToken) |
| 257 | if err != nil { |
| 258 | return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 259 | } |
| 260 | row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash) |
| 261 | if err != nil { |
| 262 | if errors.Is(err, pgx.ErrNoRows) { |
| 263 | return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 264 | } |
| 265 | return orgsdb.OrgInvitation{}, err |
| 266 | } |
| 267 | return row, nil |
| 268 | } |
| 269 | |
| 270 | // ─── helpers ─────────────────────────────────────────────────────── |
| 271 | |
var (
	// ErrUnauthorizedAcceptor is returned by AcceptInvitation and
	// DeclineInvitation when the acting user is not the invitation's
	// target. That is either a different user from the one invited by
	// username, or a user with no verified address matching the invited
	// email.
	ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user")
)
| 276 | |
| 277 | func validatePending(inv orgsdb.OrgInvitation) error { |
| 278 | if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid { |
| 279 | return ErrInvitationConsumed |
| 280 | } |
| 281 | if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) { |
| 282 | return ErrInvitationExpired |
| 283 | } |
| 284 | return nil |
| 285 | } |
| 286 | |
| 287 | func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) { |
| 288 | rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID) |
| 289 | if err != nil { |
| 290 | return false, err |
| 291 | } |
| 292 | low := strings.ToLower(email) |
| 293 | for _, e := range rows { |
| 294 | if strings.ToLower(string(e.Email)) == low && e.Verified { |
| 295 | return true, nil |
| 296 | } |
| 297 | } |
| 298 | return false, nil |
| 299 | } |
| 300 | |
| 301 | // emailToCitext is a tiny shim because the sqlc-generated email |
| 302 | // column type for `target_email` is a non-pointer pgtype.Text-flavored |
| 303 | // citext. Pass-through when valid; zero value otherwise. |
| 304 | func emailToCitext(p pgtype.Text) pgtype.Text { |
| 305 | if !p.Valid { |
| 306 | return pgtype.Text{Valid: false} |
| 307 | } |
| 308 | return p |
| 309 | } |
| 310 |