| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs |
| 4 | |
| 5 | import ( |
| 6 | "bytes" |
| 7 | "context" |
| 8 | "encoding/json" |
| 9 | "errors" |
| 10 | "fmt" |
| 11 | "io" |
| 12 | "log/slog" |
| 13 | "net/http" |
| 14 | "net/url" |
| 15 | "regexp" |
| 16 | "strconv" |
| 17 | "strings" |
| 18 | "time" |
| 19 | |
| 20 | "github.com/jackc/pgx/v5" |
| 21 | "github.com/jackc/pgx/v5/pgtype" |
| 22 | "github.com/jackc/pgx/v5/pgxpool" |
| 23 | |
| 24 | "github.com/tenseleyFlow/shithub/internal/auth/secretbox" |
| 25 | orgsdb "github.com/tenseleyFlow/shithub/internal/orgs/sqlc" |
| 26 | "github.com/tenseleyFlow/shithub/internal/worker" |
| 27 | ) |
| 28 | |
// GitHubHost is the hostname of the GitHub service.
const GitHubHost = "github.com"

// Status values for an org-level GitHub import. Completed and failed are the
// terminal states (see IsTerminalImportStatus).
const (
	ImportStatusQueued      = "queued"
	ImportStatusDiscovering = "discovering"
	ImportStatusImporting   = "importing"
	ImportStatusCompleted   = "completed"
	ImportStatusFailed      = "failed"
)

// Status values for a single repository inside an import. Imported, skipped
// and failed are the terminal states (see IsTerminalImportRepoStatus).
const (
	ImportRepoStatusQueued    = "queued"
	ImportRepoStatusImporting = "importing"
	ImportRepoStatusImported  = "imported"
	ImportRepoStatusSkipped   = "skipped"
	ImportRepoStatusFailed    = "failed"
)
| 44 | |
// ErrInvalidGitHubOrg is returned when the requested source organization is
// not an acceptable GitHub organization login.
var ErrInvalidGitHubOrg = errors.New("orgs: invalid GitHub organization")

// ErrImportTokenKeyNeeded is returned when a GitHub token has to be sealed or
// opened but no secretbox.Box is available.
var ErrImportTokenKeyNeeded = errors.New("orgs: import token encryption key is not configured")

// githubOrgRE accepts 1-39 ASCII letters, digits and hyphens that neither
// start nor end with a hyphen. It does not forbid consecutive hyphens; that
// rule is enforced separately by NormalizeGitHubOrg.
var githubOrgRE = regexp.MustCompile(`^[A-Za-z0-9](?:[A-Za-z0-9-]{0,37}[A-Za-z0-9])?$`)
| 51 | |
// ImportDeps wires org-import orchestration.
type ImportDeps struct {
	// Pool opens the transaction in which StartGitHubImport writes the import
	// row and enqueues its discovery job.
	Pool *pgxpool.Pool

	// Box seals GitHub tokens before storage. It may be nil, in which case a
	// request that carries a token fails with ErrImportTokenKeyNeeded.
	Box *secretbox.Box

	// Logger receives non-fatal warnings (e.g. a failed worker notify); nil
	// suppresses them.
	Logger *slog.Logger
}
| 58 | |
// StartGitHubImportParams describes a single org import request.
type StartGitHubImportParams struct {
	// OrgID identifies the local organization the import row is created for.
	OrgID int64

	// SourceOrg is the GitHub organization login, or a github.com profile URL;
	// it is normalized with NormalizeGitHubOrg before use.
	SourceOrg string

	// RequestedByUserID is the requesting user; 0 is stored as a NULL
	// (invalid pgtype.Int8) value.
	RequestedByUserID int64

	// Token is an optional GitHub token. When non-blank after trimming it is
	// encrypted before storage and the import is marked to include private
	// repositories.
	Token string
}
| 66 | |
| 67 | // StartGitHubImport persists a GitHub import request and enqueues discovery. |
| 68 | func StartGitHubImport(ctx context.Context, deps ImportDeps, p StartGitHubImportParams) (orgsdb.OrgGithubImport, error) { |
| 69 | sourceOrg, err := NormalizeGitHubOrg(p.SourceOrg) |
| 70 | if err != nil { |
| 71 | return orgsdb.OrgGithubImport{}, err |
| 72 | } |
| 73 | token := strings.TrimSpace(p.Token) |
| 74 | var ciphertext, nonce []byte |
| 75 | tokenPresent := token != "" |
| 76 | if tokenPresent { |
| 77 | if deps.Box == nil { |
| 78 | return orgsdb.OrgGithubImport{}, ErrImportTokenKeyNeeded |
| 79 | } |
| 80 | ciphertext, nonce, err = deps.Box.Seal([]byte(token)) |
| 81 | if err != nil { |
| 82 | return orgsdb.OrgGithubImport{}, fmt.Errorf("github import: seal token: %w", err) |
| 83 | } |
| 84 | } |
| 85 | |
| 86 | tx, err := deps.Pool.Begin(ctx) |
| 87 | if err != nil { |
| 88 | return orgsdb.OrgGithubImport{}, err |
| 89 | } |
| 90 | committed := false |
| 91 | defer func() { |
| 92 | if !committed { |
| 93 | _ = tx.Rollback(ctx) |
| 94 | } |
| 95 | }() |
| 96 | |
| 97 | q := orgsdb.New() |
| 98 | row, err := q.CreateOrgGithubImport(ctx, tx, orgsdb.CreateOrgGithubImportParams{ |
| 99 | OrgID: p.OrgID, |
| 100 | SourceOrg: sourceOrg, |
| 101 | RequestedByUserID: pgtype.Int8{Int64: p.RequestedByUserID, Valid: p.RequestedByUserID != 0}, |
| 102 | IncludePrivate: tokenPresent, |
| 103 | TokenPresent: tokenPresent, |
| 104 | TokenCiphertext: ciphertext, |
| 105 | TokenNonce: nonce, |
| 106 | }) |
| 107 | if err != nil { |
| 108 | return orgsdb.OrgGithubImport{}, fmt.Errorf("github import: create: %w", err) |
| 109 | } |
| 110 | if _, err := worker.Enqueue(ctx, tx, worker.KindOrgGitHubImportDiscover, map[string]any{ |
| 111 | "import_id": row.ID, |
| 112 | }, worker.EnqueueOptions{}); err != nil { |
| 113 | return orgsdb.OrgGithubImport{}, err |
| 114 | } |
| 115 | if err := worker.Notify(ctx, tx); err != nil && deps.Logger != nil { |
| 116 | deps.Logger.WarnContext(ctx, "github import: notify", "error", err, "import_id", row.ID) |
| 117 | } |
| 118 | if err := tx.Commit(ctx); err != nil { |
| 119 | return orgsdb.OrgGithubImport{}, err |
| 120 | } |
| 121 | committed = true |
| 122 | return row, nil |
| 123 | } |
| 124 | |
| 125 | func NormalizeGitHubOrg(raw string) (string, error) { |
| 126 | org := strings.TrimSpace(raw) |
| 127 | org = strings.TrimPrefix(org, "https://github.com/") |
| 128 | org = strings.TrimPrefix(org, "http://github.com/") |
| 129 | org = strings.Trim(org, "/") |
| 130 | if org == "" || strings.Contains(org, "/") || strings.Contains(org, "--") || !githubOrgRE.MatchString(org) { |
| 131 | return "", ErrInvalidGitHubOrg |
| 132 | } |
| 133 | return org, nil |
| 134 | } |
| 135 | |
| 136 | func DecryptGitHubImportToken(row orgsdb.OrgGithubImport, box *secretbox.Box) (string, error) { |
| 137 | if len(row.TokenCiphertext) == 0 && len(row.TokenNonce) == 0 { |
| 138 | return "", nil |
| 139 | } |
| 140 | if box == nil { |
| 141 | return "", ErrImportTokenKeyNeeded |
| 142 | } |
| 143 | pt, err := box.Open(row.TokenCiphertext, row.TokenNonce) |
| 144 | if err != nil { |
| 145 | return "", fmt.Errorf("github import: decrypt token: %w", err) |
| 146 | } |
| 147 | return string(pt), nil |
| 148 | } |
| 149 | |
// GitHubClient lists repositories through the GitHub REST API. Its zero value
// is usable: ListOrgRepos substitutes a default for every empty field.
type GitHubClient struct {
	// HTTPClient performs requests; nil means a client with a 30s timeout.
	HTTPClient *http.Client

	// BaseURL is the API root (trailing slashes ignored); empty means
	// https://api.github.com.
	BaseURL string

	// UserAgent is sent as the User-Agent header; blank means "shithub".
	UserAgent string
}
| 155 | |
// GitHubRepo is the subset of a repository record that decodeGitHubRepos
// extracts from the GitHub API's repository listing.
type GitHubRepo struct {
	// ID is decoded from the JSON "id" field.
	ID int64

	// Name is decoded from the JSON "name" field.
	Name string

	// FullName is decoded from the JSON "full_name" field.
	FullName string

	// CloneURL is decoded from the JSON "clone_url" field.
	CloneURL string

	// Description is the trimmed JSON "description"; "" when it is null.
	Description string

	// DefaultBranch is the trimmed JSON "default_branch" field.
	DefaultBranch string

	// Private is decoded from the JSON "private" field.
	Private bool

	// Fork is decoded from the JSON "fork" field.
	Fork bool
}
| 166 | |
| 167 | func (c GitHubClient) ListOrgRepos(ctx context.Context, org, token string) ([]GitHubRepo, error) { |
| 168 | org, err := NormalizeGitHubOrg(org) |
| 169 | if err != nil { |
| 170 | return nil, err |
| 171 | } |
| 172 | base := strings.TrimRight(c.BaseURL, "/") |
| 173 | if base == "" { |
| 174 | base = "https://api.github.com" |
| 175 | } |
| 176 | client := c.HTTPClient |
| 177 | if client == nil { |
| 178 | client = &http.Client{Timeout: 30 * time.Second} |
| 179 | } |
| 180 | token = strings.TrimSpace(token) |
| 181 | repoType := "public" |
| 182 | if token != "" { |
| 183 | repoType = "all" |
| 184 | } |
| 185 | var out []GitHubRepo |
| 186 | for page := 1; page <= 100; page++ { |
| 187 | u, err := url.Parse(base + "/orgs/" + url.PathEscape(org) + "/repos") |
| 188 | if err != nil { |
| 189 | return nil, err |
| 190 | } |
| 191 | q := u.Query() |
| 192 | q.Set("type", repoType) |
| 193 | q.Set("per_page", "100") |
| 194 | q.Set("page", strconv.Itoa(page)) |
| 195 | q.Set("sort", "full_name") |
| 196 | u.RawQuery = q.Encode() |
| 197 | |
| 198 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) |
| 199 | if err != nil { |
| 200 | return nil, err |
| 201 | } |
| 202 | req.Header.Set("Accept", "application/vnd.github+json") |
| 203 | req.Header.Set("X-GitHub-Api-Version", "2022-11-28") |
| 204 | req.Header.Set("User-Agent", userAgent(c.UserAgent)) |
| 205 | if token != "" { |
| 206 | req.Header.Set("Authorization", "Bearer "+token) |
| 207 | } |
| 208 | resp, err := client.Do(req) |
| 209 | if err != nil { |
| 210 | return nil, err |
| 211 | } |
| 212 | repos, err := decodeGitHubRepos(resp) |
| 213 | if err != nil { |
| 214 | return nil, err |
| 215 | } |
| 216 | out = append(out, repos...) |
| 217 | if len(repos) < 100 { |
| 218 | return out, nil |
| 219 | } |
| 220 | } |
| 221 | return nil, fmt.Errorf("github import: too many repositories in %s", org) |
| 222 | } |
| 223 | |
| 224 | func decodeGitHubRepos(resp *http.Response) ([]GitHubRepo, error) { |
| 225 | defer resp.Body.Close() |
| 226 | body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) |
| 227 | if err != nil { |
| 228 | return nil, err |
| 229 | } |
| 230 | if resp.StatusCode < 200 || resp.StatusCode >= 300 { |
| 231 | msg := strings.TrimSpace(string(body)) |
| 232 | if msg == "" { |
| 233 | msg = resp.Status |
| 234 | } |
| 235 | return nil, fmt.Errorf("github import: GitHub API returned %s: %s", resp.Status, msg) |
| 236 | } |
| 237 | var payload []struct { |
| 238 | ID int64 `json:"id"` |
| 239 | Name string `json:"name"` |
| 240 | FullName string `json:"full_name"` |
| 241 | CloneURL string `json:"clone_url"` |
| 242 | Description *string `json:"description"` |
| 243 | DefaultBranch string `json:"default_branch"` |
| 244 | Private bool `json:"private"` |
| 245 | Fork bool `json:"fork"` |
| 246 | } |
| 247 | dec := json.NewDecoder(bytes.NewReader(body)) |
| 248 | if err := dec.Decode(&payload); err != nil { |
| 249 | return nil, err |
| 250 | } |
| 251 | out := make([]GitHubRepo, 0, len(payload)) |
| 252 | for _, r := range payload { |
| 253 | desc := "" |
| 254 | if r.Description != nil { |
| 255 | desc = strings.TrimSpace(*r.Description) |
| 256 | } |
| 257 | out = append(out, GitHubRepo{ |
| 258 | ID: r.ID, |
| 259 | Name: r.Name, |
| 260 | FullName: r.FullName, |
| 261 | CloneURL: r.CloneURL, |
| 262 | Description: desc, |
| 263 | DefaultBranch: strings.TrimSpace(r.DefaultBranch), |
| 264 | Private: r.Private, |
| 265 | Fork: r.Fork, |
| 266 | }) |
| 267 | } |
| 268 | return out, nil |
| 269 | } |
| 270 | |
// userAgent picks the User-Agent header value for GitHub API requests: custom
// with surrounding whitespace removed, or "shithub" when nothing remains.
func userAgent(custom string) string {
	if trimmed := strings.TrimSpace(custom); trimmed != "" {
		return trimmed
	}
	return "shithub"
}
| 278 | |
| 279 | func IsTerminalImportStatus(status string) bool { |
| 280 | return status == ImportStatusCompleted || status == ImportStatusFailed |
| 281 | } |
| 282 | |
| 283 | func IsTerminalImportRepoStatus(status string) bool { |
| 284 | return status == ImportRepoStatusImported || status == ImportRepoStatusSkipped || status == ImportRepoStatusFailed |
| 285 | } |
| 286 | |
| 287 | func IgnoreNoRows(err error) error { |
| 288 | if errors.Is(err, pgx.ErrNoRows) { |
| 289 | return nil |
| 290 | } |
| 291 | return err |
| 292 | } |
| 293 |