@@ -0,0 +1,306 @@ |
| 1 | +// SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | + |
| 3 | +package orgs |
| 4 | + |
| 5 | +import ( |
| 6 | + "context" |
| 7 | + "errors" |
| 8 | + "fmt" |
| 9 | + "strings" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/jackc/pgx/v5" |
| 13 | + "github.com/jackc/pgx/v5/pgtype" |
| 14 | + |
| 15 | + "github.com/tenseleyFlow/shithub/internal/auth/email" |
| 16 | + "github.com/tenseleyFlow/shithub/internal/auth/token" |
| 17 | + orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 18 | + usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 19 | +) |
| 20 | + |
| 21 | +// InviteParams describes a new invitation. Exactly one of |
| 22 | +// TargetUsername / TargetEmail must be non-empty; the resolver below |
| 23 | +// converts a username to a user_id via the users table. |
| 24 | +type InviteParams struct { |
| 25 | + OrgID int64 |
| 26 | + InvitedByUserID int64 |
| 27 | + TargetUsername string |
| 28 | + TargetEmail string |
| 29 | + Role string // "owner" | "member" |
| 30 | +} |
| 31 | + |
| 32 | +// InviteResult bundles the freshly-created invitation row + the |
| 33 | +// plaintext token. Callers typically render the token into a URL and |
| 34 | +// drop it into an email; the token IS the credential — never persist |
| 35 | +// it after the email send. |
| 36 | +type InviteResult struct { |
| 37 | + Invitation orgsdb.OrgInvitation |
| 38 | + Token string |
| 39 | +} |
| 40 | + |
| 41 | +// Invite creates a pending invitation. Resolves a username to a |
| 42 | +// user_id when provided; otherwise treats the input as an email |
| 43 | +// address (recipients without an account claim it on signup). |
| 44 | +func Invite(ctx context.Context, deps Deps, p InviteParams) (InviteResult, error) { |
| 45 | + if p.OrgID == 0 || p.InvitedByUserID == 0 { |
| 46 | + return InviteResult{}, errors.New("orgs: OrgID + InvitedByUserID required") |
| 47 | + } |
| 48 | + if (p.TargetUsername == "" && p.TargetEmail == "") || |
| 49 | + (p.TargetUsername != "" && p.TargetEmail != "") { |
| 50 | + return InviteResult{}, ErrInvalidInvitationKind |
| 51 | + } |
| 52 | + role, err := parseRole(p.Role) |
| 53 | + if err != nil { |
| 54 | + return InviteResult{}, err |
| 55 | + } |
| 56 | + |
| 57 | + q := orgsdb.New() |
| 58 | + |
| 59 | + var ( |
| 60 | + targetUserID pgtype.Int8 |
| 61 | + targetEmail pgtype.Text |
| 62 | + ) |
| 63 | + if p.TargetUsername != "" { |
| 64 | + uname := strings.ToLower(strings.TrimSpace(p.TargetUsername)) |
| 65 | + u, err := usersdb.New().GetUserByUsername(ctx, deps.Pool, uname) |
| 66 | + if err != nil { |
| 67 | + if errors.Is(err, pgx.ErrNoRows) { |
| 68 | + return InviteResult{}, ErrUserNotFound |
| 69 | + } |
| 70 | + return InviteResult{}, err |
| 71 | + } |
| 72 | + // Already a member? short-circuit. |
| 73 | + if _, err := q.GetOrgMember(ctx, deps.Pool, orgsdb.GetOrgMemberParams{ |
| 74 | + OrgID: p.OrgID, UserID: u.ID, |
| 75 | + }); err == nil { |
| 76 | + return InviteResult{}, ErrAlreadyMember |
| 77 | + } |
| 78 | + targetUserID = pgtype.Int8{Int64: u.ID, Valid: true} |
| 79 | + } else { |
| 80 | + email := strings.ToLower(strings.TrimSpace(p.TargetEmail)) |
| 81 | + if email == "" { |
| 82 | + return InviteResult{}, ErrInvalidInvitationKind |
| 83 | + } |
| 84 | + targetEmail = pgtype.Text{String: email, Valid: true} |
| 85 | + } |
| 86 | + |
| 87 | + // Idempotency: if a pending invite for this target already exists, |
| 88 | + // surface it rather than minting a fresh one (avoids token churn |
| 89 | + // + accidental spam from re-clicked Invite buttons). |
| 90 | + if _, err := q.GetExistingPendingInvitation(ctx, deps.Pool, orgsdb.GetExistingPendingInvitationParams{ |
| 91 | + OrgID: p.OrgID, |
| 92 | + TargetUserID: targetUserID, |
| 93 | + TargetEmail: emailToCitext(targetEmail), |
| 94 | + }); err == nil { |
| 95 | + return InviteResult{}, ErrInvitationDuplicate |
| 96 | + } else if !errors.Is(err, pgx.ErrNoRows) { |
| 97 | + return InviteResult{}, err |
| 98 | + } |
| 99 | + |
| 100 | + tokEnc, tokHash, err := token.New() |
| 101 | + if err != nil { |
| 102 | + return InviteResult{}, fmt.Errorf("invite token: %w", err) |
| 103 | + } |
| 104 | + row, err := q.CreateOrgInvitation(ctx, deps.Pool, orgsdb.CreateOrgInvitationParams{ |
| 105 | + OrgID: p.OrgID, |
| 106 | + InvitedByUserID: pgtype.Int8{Int64: p.InvitedByUserID, Valid: true}, |
| 107 | + TargetUserID: targetUserID, |
| 108 | + TargetEmail: emailToCitext(targetEmail), |
| 109 | + Role: role, |
| 110 | + TokenHash: tokHash, |
| 111 | + ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(7 * 24 * time.Hour), Valid: true}, |
| 112 | + }) |
| 113 | + if err != nil { |
| 114 | + return InviteResult{}, fmt.Errorf("create invitation: %w", err) |
| 115 | + } |
| 116 | + |
| 117 | + // Best-effort email — failures don't break the invitation row. |
| 118 | + go h.tryEmailInvite(deps, row, tokEnc) //nolint:gocritic // closure is fine |
| 119 | + |
| 120 | + return InviteResult{Invitation: row, Token: tokEnc}, nil |
| 121 | +} |
| 122 | + |
| 123 | +// h is a tiny package-private helper struct so the closure in Invite |
| 124 | +// can pin the email-side code without dragging extra fields into |
| 125 | +// every call site. Keeps Invite signature focused. |
| 126 | +var h struct { |
| 127 | + tryEmailInvite func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) |
| 128 | +} |
| 129 | + |
| 130 | +func init() { |
| 131 | + h.tryEmailInvite = func(deps Deps, row orgsdb.OrgInvitation, tokEnc string) { |
| 132 | + if deps.EmailSender == nil { |
| 133 | + return |
| 134 | + } |
| 135 | + var to string |
| 136 | + if row.TargetEmail.Valid { |
| 137 | + to = string(row.TargetEmail.String) |
| 138 | + } else if row.TargetUserID.Valid { |
| 139 | + u, err := usersdb.New().GetUserByID(context.Background(), deps.Pool, row.TargetUserID.Int64) |
| 140 | + if err != nil || !u.PrimaryEmailID.Valid { |
| 141 | + return |
| 142 | + } |
| 143 | + em, err := usersdb.New().GetUserEmailByID(context.Background(), deps.Pool, u.PrimaryEmailID.Int64) |
| 144 | + if err != nil || !em.Verified { |
| 145 | + return |
| 146 | + } |
| 147 | + to = string(em.Email) |
| 148 | + } |
| 149 | + if to == "" { |
| 150 | + return |
| 151 | + } |
| 152 | + url := strings.TrimRight(deps.BaseURL, "/") + "/invitations/" + tokEnc |
| 153 | + text := "You've been invited to join an organization on " + deps.SiteName + ".\n\n" + |
| 154 | + "Accept or decline: " + url + "\n\n" + |
| 155 | + "This link expires in 7 days.\n" |
| 156 | + _ = deps.EmailSender.Send(context.Background(), email.Message{ |
| 157 | + From: deps.EmailFrom, |
| 158 | + To: to, |
| 159 | + Subject: "[" + deps.SiteName + "] You've been invited to an organization", |
| 160 | + Text: text, |
| 161 | + HTML: "<p>You've been invited to join an organization on " + deps.SiteName + ".</p><p><a href=\"" + url + "\">Accept or decline</a></p><p>This link expires in 7 days.</p>", |
| 162 | + }) |
| 163 | + } |
| 164 | +} |
| 165 | + |
| 166 | +// AcceptInvitation marks an invitation accepted and adds the target |
| 167 | +// as a member. The acceptor is the currently-logged-in user; for |
| 168 | +// email-based invites the acceptor's verified email must match the |
| 169 | +// invite's target_email. |
| 170 | +func AcceptInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, acceptorUserID int64) error { |
| 171 | + if err := validatePending(inv); err != nil { |
| 172 | + return err |
| 173 | + } |
| 174 | + // Email-target invites: claim only when the acceptor owns the |
| 175 | + // email (verified primary). Username-target invites: only the |
| 176 | + // matching user can accept. |
| 177 | + if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != acceptorUserID { |
| 178 | + return ErrUnauthorizedAcceptor |
| 179 | + } |
| 180 | + if inv.TargetEmail.Valid { |
| 181 | + ok, err := userOwnsVerifiedEmail(ctx, deps, acceptorUserID, string(inv.TargetEmail.String)) |
| 182 | + if err != nil { |
| 183 | + return err |
| 184 | + } |
| 185 | + if !ok { |
| 186 | + return ErrUnauthorizedAcceptor |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + tx, err := deps.Pool.Begin(ctx) |
| 191 | + if err != nil { |
| 192 | + return err |
| 193 | + } |
| 194 | + committed := false |
| 195 | + defer func() { |
| 196 | + if !committed { |
| 197 | + _ = tx.Rollback(ctx) |
| 198 | + } |
| 199 | + }() |
| 200 | + q := orgsdb.New() |
| 201 | + if err := q.AddOrgMember(ctx, tx, orgsdb.AddOrgMemberParams{ |
| 202 | + OrgID: inv.OrgID, |
| 203 | + UserID: acceptorUserID, |
| 204 | + Role: inv.Role, |
| 205 | + InvitedByUserID: inv.InvitedByUserID, |
| 206 | + }); err != nil { |
| 207 | + return fmt.Errorf("add member: %w", err) |
| 208 | + } |
| 209 | + if err := q.AcceptOrgInvitation(ctx, tx, inv.ID); err != nil { |
| 210 | + return fmt.Errorf("mark accepted: %w", err) |
| 211 | + } |
| 212 | + if err := tx.Commit(ctx); err != nil { |
| 213 | + return err |
| 214 | + } |
| 215 | + committed = true |
| 216 | + return nil |
| 217 | +} |
| 218 | + |
| 219 | +// DeclineInvitation marks an invitation declined. Same target-binding |
| 220 | +// rules as AcceptInvitation apply. |
| 221 | +func DeclineInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation, declinerUserID int64) error { |
| 222 | + if err := validatePending(inv); err != nil { |
| 223 | + return err |
| 224 | + } |
| 225 | + if inv.TargetUserID.Valid && inv.TargetUserID.Int64 != declinerUserID { |
| 226 | + return ErrUnauthorizedAcceptor |
| 227 | + } |
| 228 | + if inv.TargetEmail.Valid { |
| 229 | + ok, err := userOwnsVerifiedEmail(ctx, deps, declinerUserID, string(inv.TargetEmail.String)) |
| 230 | + if err != nil { |
| 231 | + return err |
| 232 | + } |
| 233 | + if !ok { |
| 234 | + return ErrUnauthorizedAcceptor |
| 235 | + } |
| 236 | + } |
| 237 | + return orgsdb.New().DeclineOrgInvitation(ctx, deps.Pool, inv.ID) |
| 238 | +} |
| 239 | + |
| 240 | +// CancelInvitation marks an invitation canceled. Caller is the org |
| 241 | +// owner / inviter; policy is checked before this is invoked. |
| 242 | +func CancelInvitation(ctx context.Context, deps Deps, inv orgsdb.OrgInvitation) error { |
| 243 | + if err := validatePending(inv); err != nil { |
| 244 | + return err |
| 245 | + } |
| 246 | + return orgsdb.New().CancelOrgInvitation(ctx, deps.Pool, inv.ID) |
| 247 | +} |
| 248 | + |
| 249 | +// LookupInvitationByToken fetches the invitation matching the |
| 250 | +// supplied bearer token. Returns ErrInvitationNotFound when the token |
| 251 | +// doesn't match a row. |
| 252 | +func LookupInvitationByToken(ctx context.Context, deps Deps, encodedToken string) (orgsdb.OrgInvitation, error) { |
| 253 | + hash, err := token.HashOf(encodedToken) |
| 254 | + if err != nil { |
| 255 | + return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 256 | + } |
| 257 | + row, err := orgsdb.New().GetOrgInvitationByTokenHash(ctx, deps.Pool, hash) |
| 258 | + if err != nil { |
| 259 | + if errors.Is(err, pgx.ErrNoRows) { |
| 260 | + return orgsdb.OrgInvitation{}, ErrInvitationNotFound |
| 261 | + } |
| 262 | + return orgsdb.OrgInvitation{}, err |
| 263 | + } |
| 264 | + return row, nil |
| 265 | +} |
| 266 | + |
| 267 | +// ─── helpers ─────────────────────────────────────────────────────── |
| 268 | + |
| 269 | +// ErrUnauthorizedAcceptor is returned when an acceptor doesn't own |
| 270 | +// the email the invite was issued to (or doesn't match the invited |
| 271 | +// user). |
| 272 | +var ErrUnauthorizedAcceptor = errors.New("orgs: invitation does not match this user") |
| 273 | + |
| 274 | +func validatePending(inv orgsdb.OrgInvitation) error { |
| 275 | + if inv.AcceptedAt.Valid || inv.DeclinedAt.Valid || inv.CanceledAt.Valid { |
| 276 | + return ErrInvitationConsumed |
| 277 | + } |
| 278 | + if !inv.ExpiresAt.Valid || inv.ExpiresAt.Time.Before(time.Now()) { |
| 279 | + return ErrInvitationExpired |
| 280 | + } |
| 281 | + return nil |
| 282 | +} |
| 283 | + |
| 284 | +func userOwnsVerifiedEmail(ctx context.Context, deps Deps, userID int64, email string) (bool, error) { |
| 285 | + rows, err := usersdb.New().ListUserEmailsForUser(ctx, deps.Pool, userID) |
| 286 | + if err != nil { |
| 287 | + return false, err |
| 288 | + } |
| 289 | + low := strings.ToLower(email) |
| 290 | + for _, e := range rows { |
| 291 | + if strings.ToLower(string(e.Email)) == low && e.Verified { |
| 292 | + return true, nil |
| 293 | + } |
| 294 | + } |
| 295 | + return false, nil |
| 296 | +} |
| 297 | + |
| 298 | +// emailToCitext is a tiny shim because the sqlc-generated email |
| 299 | +// column type for `target_email` is a non-pointer pgtype.Text-flavored |
| 300 | +// citext. Pass-through when valid; zero value otherwise. |
| 301 | +func emailToCitext(p pgtype.Text) pgtype.Text { |
| 302 | + if !p.Valid { |
| 303 | + return pgtype.Text{Valid: false} |
| 304 | + } |
| 305 | + return p |
| 306 | +} |