| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "errors" |
| 8 | "fmt" |
| 9 | "strings" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/jackc/pgx/v5" |
| 13 | "github.com/jackc/pgx/v5/pgtype" |
| 14 | |
| 15 | "github.com/tenseleyFlow/shithub/internal/auth/email" |
| 16 | "github.com/tenseleyFlow/shithub/internal/auth/token" |
| 17 | "github.com/tenseleyFlow/shithub/internal/entitlements" |
| 18 | orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 19 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 20 | ) |
| 21 | |
// InviteParams is the input to Invite.
//
// Callers set exactly one of TargetUsername or TargetEmail. A username
// is resolved to an existing user's ID; an email address is recorded on
// the invitation so a recipient without an account can claim it on
// signup.
type InviteParams struct {
	OrgID           int64  // organization being joined; required (non-zero)
	InvitedByUserID int64  // user issuing the invitation; required (non-zero)
	TargetUsername  string // invitee's username; mutually exclusive with TargetEmail
	TargetEmail     string // invitee's email address; mutually exclusive with TargetUsername
	Role            string // "owner" | "member"
}
| 32 | |
| 33 | // InviteResult bundles the freshly-created invitation row + the |
| 34 | // plaintext token. Callers typically render the token into a URL and |
| 35 | // drop it into an email; the token IS the credential — never persist |
| 36 | // it after the email send. |
| 37 | type InviteResult struct { |
| 38 | Invitation orgsdb.OrgInvitation |
| 39 | Token string |
| 40 | } |
| 41 | |
| 42 | // Invite creates a pending invitation. Resolves a username to a |
| 43 | // user_id when provided; otherwise treats the input as an email |
| 44 | // address (recipients without an account claim it on signup). |
| 45 | func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) { |
| 46 | if p.OrgID == 0 || p.InvitedByUserID == 0 { |
| 47 | return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required") |
| 48 | } |
| 49 | if (p.TargetUsername == "" && p.TargetEmail == "") || |
| 50 | (p.TargetUsername != "" && p.TargetEmail != "") { |
| 51 | return InviteResult{}, ErrInvalidInvitationKind |
| 52 | } |
| 53 | role, err := parseRole(p.Role) |
| 54 | if err != nil { |
| 55 | return InviteResult{}, err |
| 56 | } |
| 57 | |
| 58 | q := orgsdb.New() |
| 59 | |
| 60 | var ( |
| 61 | targetUserID pgtype.Int8 |
| 62 | targetEmail pgtype.Text |
| 63 | ) |
| 64 | if p.TargetUsername != "" { |
| 65 | uname := strings.ToLower(strings.TrimSpace(p.TargetUsername)) |
| 66 | u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname) |
| 67 | if err != nil { |
| 68 | if errors.Is(err, pgx.ErrNoRows) { |
| 69 | return InviteResult{}, ErrUserNotFound |
| 70 | } |
| 71 | return InviteResult{}, err |
| 72 | } |
| 73 | // Already a member? short-circuit. |
| 74 | if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{ |
| 75 | OrgID: p.OrgID, UserID: u.ID, |
| 76 | }); err == nil { |
| 77 | return InviteResult{}, ErrAlreadyMember |
| 78 | } |
| 79 | targetUserID = pgtype.Int8{Int64: u.ID, Valid: true} |
| 80 | } else { |
| 81 | email := strings.ToLower(strings.TrimSpace(p.TargetEmail)) |
| 82 | if email == "" { |
| 83 | return InviteResult{}, ErrInvalidInvitationKind |
| 84 | } |
| 85 | targetEmail = pgtype.Text{String: email, Valid: true} |
| 86 | } |
| 87 | |
| 88 | // Idempotency: if a pending invite for this target already exists, |
| 89 | // surface it rather than minting a fresh one (avoids token churn |
| 90 | // + accidental spam from re-clicked Invite buttons). |
| 91 | if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{ |
| 92 | OrgID: p.OrgID, |
| 93 | TargetUserID: targetUserID, |
| 94 | TargetEmail: emailToCitext(targetEmail), |
| 95 | }); err == nil { |
| 96 | return InviteResult{}, ErrInvitationDuplicate |
| 97 | } else if !errors.Is(err, pgx.ErrNoRows) { |
| 98 | return InviteResult{}, err |
| 99 | } |
| 100 | if role == orgsdb.OrgRoleOwner { |
| 101 | var check entitlements.PrivateCollaborationCheck |
| 102 | if targetUserID.Valid { |
| 103 | check, err = entitlements.CheckOrgOwnerPrivateCollaboration(ctx, entitlements.Deps{Pool: deps.Pool}, p.OrgID, targetUserID.Int64) |
| 104 | } else { |
| 105 | check, err = entitlements.CheckPrivateInvitationSlot(ctx, entitlements.Deps{Pool: deps.Pool}, p.OrgID) |
| 106 | } |
| 107 | if err != nil { |
| 108 | return InviteResult{}, err |
| 109 | } |
| 110 | if err := check.Err(); err != nil { |
| 111 | return InviteResult{}, err |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | tokEnc, tokHash, err := token.New() |
| 116 | if err != nil { |
| 117 | return InviteResult{}, fmt.Errorf("invite token: %w", err) |
| 118 | } |
| 119 | row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{ |
| 120 | OrgID: p.OrgID, |
| 121 | InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true}, |
| 122 | TargetUserID: targetUserID, |
| 123 | TargetEmail: emailToCitext(targetEmail), |
| 124 | Role: role, |
| 125 | TokenHash: tokHash, |
| 126 | ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true}, |
| 127 | }) |
| 128 | if err != nil { |
| 129 | return InviteResult{}, fmt.Errorf("create invitation: %w", err) |
| 130 | } |
| 131 | |
| 132 | // Best-effort email — failures don't break the invitation row. |
| 133 | go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine |
| 134 | |
| 135 | return InviteResult{Invitation: row, Token: tokEnc}, nil |
| 136 | } |
| 137 | |
| 138 | // h is a tiny package-private helper struct so the closure in Invite |
| 139 | // can pin the email-side code without dragging extra fields into |
| 140 | // every call site. Keeps Invite signature focused. |
| 141 | var h struct { |
| 142 | tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) |
| 143 | } |
| 144 | |
| 145 | func init() { |
| 146 | h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) { |
| 147 | if deps.EmailSender == nil { |
| 148 | return |
| 149 | } |
| 150 | var to string |
| 151 | if row.TargetEmail.Valid { |
| 152 | to = string(row.TargetEmail.String) |
| 153 | } else if row.TargetUserID.Valid { |
| 154 | u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64) |
| 155 | if err != nil || !u.PrimaryEmailID.Valid { |
| 156 | return |
| 157 | } |
| 158 | em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64) |
| 159 | if err != nil || !em.Verified { |
| 160 | return |
| 161 | } |
| 162 | to = string(em.Email) |
| 163 | } |
| 164 | if to == "" { |
| 165 | return |
| 166 | } |
| 167 | url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc |
| 168 | text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" + |
| 169 | "Accept or decline: " + url + "\n\n" + |
| 170 | "This link expires in 7 days.\n" |
| 171 | _ = deps.EmailSender.Send(context.Background(), email.Message{ |
| 172 | From: deps.EmailFrom, |
| 173 | To: to, |
| 174 | Subject: "[" + deps.SiteName + "] You've been invited to an organization", |
| 175 | Text: text, |
| 176 | HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>", |
| 177 | }) |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | // AcceptInvitation marks an invitation accepted and adds the target |
| 182 | // as a member. The acceptor is the currently-logged-in user; for |
| 183 | // email-based invites the acceptor's verified email must match the |
| 184 | // invite's target_email. |
| 185 | func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error { |
| 186 | if err := validatePending(inv); err != nil { |
| 187 | return err |
| 188 | } |
| 189 | // Email-target invites: claim only when the acceptor owns the |
| 190 | // email (verified primary). Username-target invites: only the |
| 191 | // matching user can accept. |
| 192 | if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID { |
| 193 | return ErrUnauthorizedAcceptor |
| 194 | } |
| 195 | if inv.TargetEmail.Valid { |
| 196 | ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String)) |
| 197 | if err != nil { |
| 198 | return err |
| 199 | } |
| 200 | if !ok { |
| 201 | return ErrUnauthorizedAcceptor |
| 202 | } |
| 203 | } |
| 204 | if inv.Role == orgsdb.OrgRoleOwner { |
| 205 | check, err := entitlements.CheckOrgOwnerPrivateCollaboration(ctx, entitlements.Deps{Pool: deps.Pool}, inv.OrgID, acceptorUserID) |
| 206 | if err != nil { |
| 207 | return err |
| 208 | } |
| 209 | if err := check.Err(); err != nil { |
| 210 | return err |
| 211 | } |
| 212 | } |
| 213 | |
| 214 | tx, err := deps.Pool.Begin(ctx) |
| 215 | if err != nil { |
| 216 | return err |
| 217 | } |
| 218 | committed := false |
| 219 | defer func() { |
| 220 | if !committed { |
| 221 | _ = tx.Rollback(ctx) |
| 222 | } |
| 223 | }() |
| 224 | q := orgsdb.New() |
| 225 | if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{ |
| 226 | OrgID: inv.OrgID, |
| 227 | UserID: acceptorUserID, |
| 228 | Role: inv.Role, |
| 229 | InvitedByUserID: inv.InvitedByUserID, |
| 230 | }); err != nil { |
| 231 | return fmt.Errorf("add member: %w", err) |
| 232 | } |
| 233 | if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil { |
| 234 | return fmt.Errorf("mark accepted: %w", err) |
| 235 | } |
| 236 | if err := enqueueBillingSeatSync(ctx, tx, deps, inv.OrgID); err != nil { |
| 237 | return fmt.Errorf("enqueue billing seat sync: %w", err) |
| 238 | } |
| 239 | if err := tx.Commit(ctx); err != nil { |
| 240 | return err |
| 241 | } |
| 242 | committed = true |
| 243 | return nil |
| 244 | } |
| 245 | |
| 246 | // DeclineInvitation marks an invitation declined. Same target-binding |
| 247 | // rules as AcceptInvitation apply. |
| 248 | func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error { |
| 249 | if err := validatePending(inv); err != nil { |
| 250 | return err |
| 251 | } |
| 252 | if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID { |
| 253 | return ErrUnauthorizedAcceptor |
| 254 | } |
| 255 | if inv.TargetEmail.Valid { |
| 256 | ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String)) |
| 257 | if err != nil { |
| 258 | return err |
| 259 | } |
| 260 | if !ok { |
| 261 | return ErrUnauthorizedAcceptor |
| 262 | } |
| 263 | } |
| 264 | return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID) |
| 265 | } |
| 266 | |
| 267 | // CancelInvitation marks an invitation canceled. Caller is the org |
| 268 | // owner / inviter; policy is checked before this is invoked. |
| 269 | func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error { |
| 270 | if err := validatePending(inv); err != nil { |
| 271 | return err |
| 272 | } |
| 273 | return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID) |
| 274 | } |
| 275 | |
| 276 | // LookupInvitationByToken fetches the invitation matching the |
| 277 | // supplied bearer token. Returns ErrInvitationNotFound when the token |
| 278 | // doesn't match a row. |
| 279 | func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) { |
| 280 | hash, err := token.HashOf(encodedToken) |
| 281 | if err != nil { |
| 282 | return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 283 | } |
| 284 | row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash) |
| 285 | if err != nil { |
| 286 | if errors.Is(err, pgx.ErrNoRows) { |
| 287 | return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 288 | } |
| 289 | return orgsdb.OrgInvitation{}, err |
| 290 | } |
| 291 | return row, nil |
| 292 | } |
| 293 | |
| 294 | // ─── helpers ─────────────────────────────────────────────────────── |
| 295 | |
var (
	// ErrUnauthorizedAcceptor reports that the user acting on an
	// invitation is not its target: either not the invited user, or not
	// a verified owner of the invited email address.
	ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user")
)
| 300 | |
| 301 | func validatePending(inv orgsdb.OrgInvitation) error { |
| 302 | if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid { |
| 303 | return ErrInvitationConsumed |
| 304 | } |
| 305 | if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) { |
| 306 | return ErrInvitationExpired |
| 307 | } |
| 308 | return nil |
| 309 | } |
| 310 | |
| 311 | func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) { |
| 312 | rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID) |
| 313 | if err != nil { |
| 314 | return false, err |
| 315 | } |
| 316 | low := strings.ToLower(email) |
| 317 | for _, e := range rows { |
| 318 | if strings.ToLower(string(e.Email)) == low && e.Verified { |
| 319 | return true, nil |
| 320 | } |
| 321 | } |
| 322 | return false, nil |
| 323 | } |
| 324 | |
| 325 | // emailToCitext is a tiny shim because the sqlc-generated email |
| 326 | // column type for `target_email` is a non-pointer pgtype.Text-flavored |
| 327 | // citext. Pass-through when valid; zero value otherwise. |
| 328 | func emailToCitext(p pgtype.Text) pgtype.Text { |
| 329 | if !p.Valid { |
| 330 | return pgtype.Text{Valid: false} |
| 331 | } |
| 332 | return p |
| 333 | } |
| 334 |