@@ -0,0 +1,306 @@ |
| | 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| | 2 | + |
| | 3 | +package orgs |
| | 4 | + |
| | 5 | +import ( |
| | 6 | + "context" |
| | 7 | + "errors" |
| | 8 | + "fmt" |
| | 9 | + "strings" |
| | 10 | + "time" |
| | 11 | + |
| | 12 | + "github.com/jackc/pgx/v5" |
| | 13 | + "github.com/jackc/pgx/v5/pgtype" |
| | 14 | + |
| | 15 | + "github.com/tenseleyFlow/shithub/internal/auth/email" |
| | 16 | + "github.com/tenseleyFlow/shithub/internal/auth/token" |
| | 17 | + orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| | 18 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| | 19 | +) |
| | 20 | + |
| | 21 | +// InviteParams describes a new invitation. Exactly one of |
| | 22 | +// TargetUsername / TargetEmail must be non-empty; the resolver below |
| | 23 | +// converts a username to a user_id via the users table. |
| | 24 | +type InviteParams struct { |
| | 25 | + OrgID int64 |
| | 26 | + InvitedByUserID int64 |
| | 27 | + TargetUsername string |
| | 28 | + TargetEmail string |
| | 29 | + Role string // "owner" | "member" |
| | 30 | +} |
| | 31 | + |
| | 32 | +// InviteResult bundles the freshly-created invitation row + the |
| | 33 | +// plaintext token. Callers typically render the token into a URL and |
| | 34 | +// drop it into an email; the token IS the credential — never persist |
| | 35 | +// it after the email send. |
| | 36 | +type InviteResult struct { |
| | 37 | + Invitation orgsdb.OrgInvitation |
| | 38 | + Token string |
| | 39 | +} |
| | 40 | + |
| | 41 | +// Invite creates a pending invitation. Resolves a username to a |
| | 42 | +// user_id when provided; otherwise treats the input as an email |
| | 43 | +// address (recipients without an account claim it on signup). |
| | 44 | +func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) { |
| | 45 | + if p.OrgID == 0 || p.InvitedByUserID == 0 { |
| | 46 | + return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required") |
| | 47 | + } |
| | 48 | + if (p.TargetUsername == "" && p.TargetEmail == "") || |
| | 49 | + (p.TargetUsername != "" && p.TargetEmail != "") { |
| | 50 | + return InviteResult{}, ErrInvalidInvitationKind |
| | 51 | + } |
| | 52 | + role, err := parseRole(p.Role) |
| | 53 | + if err != nil { |
| | 54 | + return InviteResult{}, err |
| | 55 | + } |
| | 56 | + |
| | 57 | + q := orgsdb.New() |
| | 58 | + |
| | 59 | + var ( |
| | 60 | + targetUserID pgtype.Int8 |
| | 61 | + targetEmail pgtype.Text |
| | 62 | + ) |
| | 63 | + if p.TargetUsername != "" { |
| | 64 | + uname := strings.ToLower(strings.TrimSpace(p.TargetUsername)) |
| | 65 | + u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname) |
| | 66 | + if err != nil { |
| | 67 | + if errors.Is(err, pgx.ErrNoRows) { |
| | 68 | + return InviteResult{}, ErrUserNotFound |
| | 69 | + } |
| | 70 | + return InviteResult{}, err |
| | 71 | + } |
| | 72 | + // Already a member? short-circuit. |
| | 73 | + if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{ |
| | 74 | + OrgID: p.OrgID, UserID: u.ID, |
| | 75 | + }); err == nil { |
| | 76 | + return InviteResult{}, ErrAlreadyMember |
| | 77 | + } |
| | 78 | + targetUserID = pgtype.Int8{Int64: u.ID, Valid: true} |
| | 79 | + } else { |
| | 80 | + email := strings.ToLower(strings.TrimSpace(p.TargetEmail)) |
| | 81 | + if email == "" { |
| | 82 | + return InviteResult{}, ErrInvalidInvitationKind |
| | 83 | + } |
| | 84 | + targetEmail = pgtype.Text{String: email, Valid: true} |
| | 85 | + } |
| | 86 | + |
| | 87 | + // Idempotency: if a pending invite for this target already exists, |
| | 88 | + // surface it rather than minting a fresh one (avoids token churn |
| | 89 | + // + accidental spam from re-clicked Invite buttons). |
| | 90 | + if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{ |
| | 91 | + OrgID: p.OrgID, |
| | 92 | + TargetUserID: targetUserID, |
| | 93 | + TargetEmail: emailToCitext(targetEmail), |
| | 94 | + }); err == nil { |
| | 95 | + return InviteResult{}, ErrInvitationDuplicate |
| | 96 | + } else if !errors.Is(err, pgx.ErrNoRows) { |
| | 97 | + return InviteResult{}, err |
| | 98 | + } |
| | 99 | + |
| | 100 | + tokEnc, tokHash, err := token.New() |
| | 101 | + if err != nil { |
| | 102 | + return InviteResult{}, fmt.Errorf("invite token: %w", err) |
| | 103 | + } |
| | 104 | + row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{ |
| | 105 | + OrgID: p.OrgID, |
| | 106 | + InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true}, |
| | 107 | + TargetUserID: targetUserID, |
| | 108 | + TargetEmail: emailToCitext(targetEmail), |
| | 109 | + Role: role, |
| | 110 | + TokenHash: tokHash, |
| | 111 | + ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true}, |
| | 112 | + }) |
| | 113 | + if err != nil { |
| | 114 | + return InviteResult{}, fmt.Errorf("create invitation: %w", err) |
| | 115 | + } |
| | 116 | + |
| | 117 | + // Best-effort email — failures don't break the invitation row. |
| | 118 | + go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine |
| | 119 | + |
| | 120 | + return InviteResult{Invitation: row, Token: tokEnc}, nil |
| | 121 | +} |
| | 122 | + |
| | 123 | +// h is a tiny package-private helper struct so the closure in Invite |
| | 124 | +// can pin the email-side code without dragging extra fields into |
| | 125 | +// every call site. Keeps Invite signature focused. |
| | 126 | +var h struct { |
| | 127 | + tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) |
| | 128 | +} |
| | 129 | + |
| | 130 | +func init() { |
| | 131 | + h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) { |
| | 132 | + if deps.EmailSender == nil { |
| | 133 | + return |
| | 134 | + } |
| | 135 | + var to string |
| | 136 | + if row.TargetEmail.Valid { |
| | 137 | + to = string(row.TargetEmail.String) |
| | 138 | + } else if row.TargetUserID.Valid { |
| | 139 | + u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64) |
| | 140 | + if err != nil || !u.PrimaryEmailID.Valid { |
| | 141 | + return |
| | 142 | + } |
| | 143 | + em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64) |
| | 144 | + if err != nil || !em.Verified { |
| | 145 | + return |
| | 146 | + } |
| | 147 | + to = string(em.Email) |
| | 148 | + } |
| | 149 | + if to == "" { |
| | 150 | + return |
| | 151 | + } |
| | 152 | + url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc |
| | 153 | + text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" + |
| | 154 | + "Accept or decline: " + url + "\n\n" + |
| | 155 | + "This link expires in 7 days.\n" |
| | 156 | + _ = deps.EmailSender.Send(context.Background(), email.Message{ |
| | 157 | + From: deps.EmailFrom, |
| | 158 | + To: to, |
| | 159 | + Subject: "[" + deps.SiteName + "] You've been invited to an organization", |
| | 160 | + Text: text, |
| | 161 | + HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>", |
| | 162 | + }) |
| | 163 | + } |
| | 164 | +} |
| | 165 | + |
| | 166 | +// AcceptInvitation marks an invitation accepted and adds the target |
| | 167 | +// as a member. The acceptor is the currently-logged-in user; for |
| | 168 | +// email-based invites the acceptor's verified email must match the |
| | 169 | +// invite's target_email. |
| | 170 | +func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error { |
| | 171 | + if err := validatePending(inv); err != nil { |
| | 172 | + return err |
| | 173 | + } |
| | 174 | + // Email-target invites: claim only when the acceptor owns the |
| | 175 | + // email (verified primary). Username-target invites: only the |
| | 176 | + // matching user can accept. |
| | 177 | + if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID { |
| | 178 | + return ErrUnauthorizedAcceptor |
| | 179 | + } |
| | 180 | + if inv.TargetEmail.Valid { |
| | 181 | + ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String)) |
| | 182 | + if err != nil { |
| | 183 | + return err |
| | 184 | + } |
| | 185 | + if !ok { |
| | 186 | + return ErrUnauthorizedAcceptor |
| | 187 | + } |
| | 188 | + } |
| | 189 | + |
| | 190 | + tx, err := deps.Pool.Begin(ctx) |
| | 191 | + if err != nil { |
| | 192 | + return err |
| | 193 | + } |
| | 194 | + committed := false |
| | 195 | + defer func() { |
| | 196 | + if !committed { |
| | 197 | + _ = tx.Rollback(ctx) |
| | 198 | + } |
| | 199 | + }() |
| | 200 | + q := orgsdb.New() |
| | 201 | + if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{ |
| | 202 | + OrgID: inv.OrgID, |
| | 203 | + UserID: acceptorUserID, |
| | 204 | + Role: inv.Role, |
| | 205 | + InvitedByUserID: inv.InvitedByUserID, |
| | 206 | + }); err != nil { |
| | 207 | + return fmt.Errorf("add member: %w", err) |
| | 208 | + } |
| | 209 | + if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil { |
| | 210 | + return fmt.Errorf("mark accepted: %w", err) |
| | 211 | + } |
| | 212 | + if err := tx.Commit(ctx); err != nil { |
| | 213 | + return err |
| | 214 | + } |
| | 215 | + committed = true |
| | 216 | + return nil |
| | 217 | +} |
| | 218 | + |
| | 219 | +// DeclineInvitation marks an invitation declined. Same target-binding |
| | 220 | +// rules as AcceptInvitation apply. |
| | 221 | +func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error { |
| | 222 | + if err := validatePending(inv); err != nil { |
| | 223 | + return err |
| | 224 | + } |
| | 225 | + if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID { |
| | 226 | + return ErrUnauthorizedAcceptor |
| | 227 | + } |
| | 228 | + if inv.TargetEmail.Valid { |
| | 229 | + ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String)) |
| | 230 | + if err != nil { |
| | 231 | + return err |
| | 232 | + } |
| | 233 | + if !ok { |
| | 234 | + return ErrUnauthorizedAcceptor |
| | 235 | + } |
| | 236 | + } |
| | 237 | + return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID) |
| | 238 | +} |
| | 239 | + |
| | 240 | +// CancelInvitation marks an invitation canceled. Caller is the org |
| | 241 | +// owner / inviter; policy is checked before this is invoked. |
| | 242 | +func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error { |
| | 243 | + if err := validatePending(inv); err != nil { |
| | 244 | + return err |
| | 245 | + } |
| | 246 | + return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID) |
| | 247 | +} |
| | 248 | + |
| | 249 | +// LookupInvitationByToken fetches the invitation matching the |
| | 250 | +// supplied bearer token. Returns ErrInvitationNotFound when the token |
| | 251 | +// doesn't match a row. |
| | 252 | +func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) { |
| | 253 | + hash, err := token.HashOf(encodedToken) |
| | 254 | + if err != nil { |
| | 255 | + return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| | 256 | + } |
| | 257 | + row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash) |
| | 258 | + if err != nil { |
| | 259 | + if errors.Is(err, pgx.ErrNoRows) { |
| | 260 | + return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| | 261 | + } |
| | 262 | + return orgsdb.OrgInvitation{}, err |
| | 263 | + } |
| | 264 | + return row, nil |
| | 265 | +} |
| | 266 | + |
| | 267 | +// ─── helpers ─────────────────────────────────────────────────────── |
| | 268 | + |
| | 269 | +// ErrUnauthorizedAcceptor is returned when an acceptor doesn't own |
| | 270 | +// the email the invite was issued to (or doesn't match the invited |
| | 271 | +// user). |
| | 272 | +var ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user") |
| | 273 | + |
| | 274 | +func validatePending(inv orgsdb.OrgInvitation) error { |
| | 275 | + if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid { |
| | 276 | + return ErrInvitationConsumed |
| | 277 | + } |
| | 278 | + if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) { |
| | 279 | + return ErrInvitationExpired |
| | 280 | + } |
| | 281 | + return nil |
| | 282 | +} |
| | 283 | + |
| | 284 | +func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) { |
| | 285 | + rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID) |
| | 286 | + if err != nil { |
| | 287 | + return false, err |
| | 288 | + } |
| | 289 | + low := strings.ToLower(email) |
| | 290 | + for _, e := range rows { |
| | 291 | + if strings.ToLower(string(e.Email)) == low && e.Verified { |
| | 292 | + return true, nil |
| | 293 | + } |
| | 294 | + } |
| | 295 | + return false, nil |
| | 296 | +} |
| | 297 | + |
| | 298 | +// emailToCitext is a tiny shim because the sqlc-generated email |
| | 299 | +// column type for `target_email` is a non-pointer pgtype.Text-flavored |
| | 300 | +// citext. Pass-through when valid; zero value otherwise. |
| | 301 | +func emailToCitext(p pgtype.Text) pgtype.Text { |
| | 302 | + if !p.Valid { |
| | 303 | + return pgtype.Text{Valid: false} |
| | 304 | + } |
| | 305 | + return p |
| | 306 | +} |