Go · 7926 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package repos
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "log/slog"
10 "net/url"
11 "strings"
12 "time"
13
14 "github.com/jackc/pgx/v5/pgtype"
15 "github.com/jackc/pgx/v5/pgxpool"
16
17 "github.com/tenseleyFlow/shithub/internal/infra/storage"
18 repogit "github.com/tenseleyFlow/shithub/internal/repos/git"
19 reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
20 "github.com/tenseleyFlow/shithub/internal/security/ssrf"
21 "github.com/tenseleyFlow/shithub/internal/worker"
22 )
23
// Limits applied to source-remote handling.
const (
	// MaxSourceRemoteURLLen caps the length, in bytes, of both the raw input
	// and the normalized source remote URL.
	MaxSourceRemoteURLLen = 2048
	// SourceRemoteFetchTimeout bounds a single git fetch from a source remote.
	SourceRemoteFetchTimeout = 45 * time.Second
)

// ErrInvalidSourceRemote is wrapped by every validation failure returned from
// NormalizeSourceRemoteURL and ValidateSourceRemoteURL; match it with
// errors.Is or IsInvalidSourceRemote.
var ErrInvalidSourceRemote = errors.New("repos: invalid source remote URL")

// NormalizeSourceRemoteURL validates and canonicalizes the public Git
// remote URL shithub is allowed to fetch from for source imports and
// submodule commit backfills. Credentials are deliberately not allowed
// here; private import credentials need a separate secret-backed design.
//
// Surrounding whitespace is trimmed, and a blank input yields ("", nil),
// meaning "no remote configured". An accepted URL must:
//   - use the http(s) scheme;
//   - name a host and a non-empty repository path;
//   - carry no userinfo, no query string (including a bare trailing "?"),
//     and no fragment.
//
// The scheme and host are lower-cased. Both the raw input and the normalized
// result must fit within MaxSourceRemoteURLLen. Every failure wraps
// ErrInvalidSourceRemote.
func NormalizeSourceRemoteURL(raw string) (string, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return "", nil
	}
	if len(raw) > MaxSourceRemoteURLLen {
		return "", fmt.Errorf("%w: too long", ErrInvalidSourceRemote)
	}
	u, err := url.Parse(raw)
	if err != nil {
		return "", fmt.Errorf("%w: malformed URL", ErrInvalidSourceRemote)
	}
	switch strings.ToLower(u.Scheme) {
	case "http", "https":
	default:
		return "", fmt.Errorf("%w: source imports currently support http(s) git remotes", ErrInvalidSourceRemote)
	}
	if u.Hostname() == "" {
		return "", fmt.Errorf("%w: missing host", ErrInvalidSourceRemote)
	}
	if u.User != nil {
		return "", fmt.Errorf("%w: credentials are not supported in source remote URLs", ErrInvalidSourceRemote)
	}
	// url.Parse records a bare trailing "?" as ForceQuery with an empty
	// RawQuery, and u.String() would re-emit it, so treat it like any query.
	if u.RawQuery != "" || u.ForceQuery || u.Fragment != "" {
		return "", fmt.Errorf("%w: query strings and fragments are not supported", ErrInvalidSourceRemote)
	}
	if strings.Trim(u.EscapedPath(), "/") == "" {
		return "", fmt.Errorf("%w: missing repository path", ErrInvalidSourceRemote)
	}
	u.Scheme = strings.ToLower(u.Scheme)
	u.Host = strings.ToLower(u.Host)
	normalized := u.String()
	// Re-escaping (e.g. a space becomes %20) can lengthen the URL, so the cap
	// is also enforced on the value that is actually stored and fetched.
	if len(normalized) > MaxSourceRemoteURLLen {
		return "", fmt.Errorf("%w: too long", ErrInvalidSourceRemote)
	}
	return normalized, nil
}
68
69 // ValidateSourceRemoteURL runs the same SSRF defenses used for webhooks
70 // before a URL is persisted or fetched by git. Git still receives the URL
71 // as argv (never through a shell), and fetch disables submodule recursion.
72 func ValidateSourceRemoteURL(ctx context.Context, raw string) (string, error) {
73 normalized, err := NormalizeSourceRemoteURL(raw)
74 if err != nil || normalized == "" {
75 return normalized, err
76 }
77 if err := ssrf.Default().ValidateWithResolve(ctx, normalized); err != nil {
78 return "", fmt.Errorf("%w: %v", ErrInvalidSourceRemote, err)
79 }
80 return normalized, nil
81 }
82
83 // SourceRemoteDeps wires source-remote fetches. FetchToken is optional and
84 // only used for private GitHub imports; it is not stored in repo_source_remotes.
85 type SourceRemoteDeps struct {
86 Pool *pgxpool.Pool
87 RepoFS *storage.RepoFS
88 Logger *slog.Logger
89 FetchToken string
90 }
91
92 // SaveSourceRemote validates and persists the credential-free source remote.
93 func SaveSourceRemote(ctx context.Context, deps SourceRemoteDeps, repoID int64, rawURL string) (string, error) {
94 remoteURL, err := ValidateSourceRemoteURL(ctx, rawURL)
95 if err != nil || remoteURL == "" {
96 return remoteURL, err
97 }
98 _, err = reposdb.New().UpsertRepoSourceRemote(ctx, deps.Pool, reposdb.UpsertRepoSourceRemoteParams{
99 RepoID: repoID,
100 RemoteUrl: remoteURL,
101 })
102 return remoteURL, err
103 }
104
105 // FetchSourceRemote imports public heads/tags from a configured source remote
106 // and updates cached default-branch/index/size state.
107 func FetchSourceRemote(ctx context.Context, deps SourceRemoteDeps, row reposdb.Repo, ownerSlug, remoteURL string) error {
108 remoteURL, err := ValidateSourceRemoteURL(ctx, remoteURL)
109 if err != nil {
110 MarkSourceRemoteFetchError(ctx, deps, row.ID, err)
111 return err
112 }
113 gitDir, err := deps.RepoFS.RepoPath(ownerSlug, row.Name)
114 if err != nil {
115 MarkSourceRemoteFetchError(ctx, deps, row.ID, err)
116 return err
117 }
118 fetchCtx, cancel := context.WithTimeout(ctx, SourceRemoteFetchTimeout)
119 defer cancel()
120 if strings.TrimSpace(deps.FetchToken) != "" {
121 err = repogit.FetchRemoteHeadsAndTagsWithToken(fetchCtx, gitDir, remoteURL, deps.FetchToken)
122 } else {
123 err = repogit.FetchRemoteHeadsAndTags(fetchCtx, gitDir, remoteURL)
124 }
125 if err != nil {
126 MarkSourceRemoteFetchError(ctx, deps, row.ID, err)
127 return err
128 }
129 if err := RefreshFetchedRepoState(ctx, deps, row, gitDir); err != nil {
130 MarkSourceRemoteFetchError(ctx, deps, row.ID, err)
131 return err
132 }
133 q := reposdb.New()
134 if err := q.MarkRepoSourceRemoteFetched(ctx, deps.Pool, row.ID); err != nil && deps.Logger != nil {
135 deps.Logger.WarnContext(ctx, "source-remote: mark fetched", "error", err, "repo_id", row.ID)
136 }
137 return nil
138 }
139
140 // RefreshFetchedRepoState reconciles the repo row after a source fetch.
141 func RefreshFetchedRepoState(ctx context.Context, deps SourceRemoteDeps, row reposdb.Repo, gitDir string) error {
142 refs, err := repogit.ListRefs(ctx, gitDir)
143 if err != nil {
144 return err
145 }
146 branch, oid := ChooseFetchedDefaultBranch(row.DefaultBranch, refs.Branches)
147 if branch == "" {
148 return nil
149 }
150 q := reposdb.New()
151 if branch != row.DefaultBranch {
152 if err := q.UpdateRepoDefaultBranch(ctx, deps.Pool, reposdb.UpdateRepoDefaultBranchParams{
153 ID: row.ID,
154 DefaultBranch: branch,
155 }); err != nil {
156 return err
157 }
158 if err := repogit.SetSymbolicRef(ctx, gitDir, "HEAD", "refs/heads/"+branch); err != nil && deps.Logger != nil {
159 deps.Logger.WarnContext(ctx, "source-remote: set symbolic head", "error", err, "repo_id", row.ID, "branch", branch)
160 }
161 }
162 if !row.DefaultBranchOid.Valid || row.DefaultBranchOid.String != oid {
163 if err := q.UpdateRepoDefaultBranchOID(ctx, deps.Pool, reposdb.UpdateRepoDefaultBranchOIDParams{
164 ID: row.ID,
165 DefaultBranchOid: pgtype.Text{String: oid, Valid: true},
166 }); err != nil {
167 return err
168 }
169 if _, err := worker.Enqueue(ctx, deps.Pool, worker.KindRepoIndexCode, map[string]any{"repo_id": row.ID}, worker.EnqueueOptions{}); err != nil && deps.Logger != nil {
170 deps.Logger.WarnContext(ctx, "source-remote: enqueue index", "error", err, "repo_id", row.ID)
171 }
172 }
173 if _, err := worker.Enqueue(ctx, deps.Pool, worker.KindRepoSizeRecalc, map[string]any{"repo_id": row.ID}, worker.EnqueueOptions{}); err != nil && deps.Logger != nil {
174 deps.Logger.WarnContext(ctx, "source-remote: enqueue size", "error", err, "repo_id", row.ID)
175 }
176 _ = worker.Notify(ctx, deps.Pool)
177 return nil
178 }
179
180 // ChooseFetchedDefaultBranch mirrors GitHub import behavior: keep the current
181 // default if present, otherwise prefer trunk/main/master before falling back to
182 // the first fetched branch.
183 func ChooseFetchedDefaultBranch(current string, branches []repogit.RefEntry) (name, oid string) {
184 if len(branches) == 0 {
185 return "", ""
186 }
187 for _, candidate := range []string{current, "trunk", "main", "master"} {
188 if candidate == "" {
189 continue
190 }
191 for _, branch := range branches {
192 if branch.Name == candidate {
193 return branch.Name, branch.OID
194 }
195 }
196 }
197 return branches[0].Name, branches[0].OID
198 }
199
200 // MarkSourceRemoteFetchError stores the latest source-fetch failure without
201 // leaking credentials; callers pass credential-free remote URLs.
202 func MarkSourceRemoteFetchError(ctx context.Context, deps SourceRemoteDeps, repoID int64, err error) {
203 if err == nil {
204 return
205 }
206 msg := strings.TrimSpace(err.Error())
207 if len(msg) > 500 {
208 msg = msg[:500]
209 }
210 if markErr := reposdb.New().MarkRepoSourceRemoteFetchError(ctx, deps.Pool, reposdb.MarkRepoSourceRemoteFetchErrorParams{
211 RepoID: repoID,
212 LastError: pgtype.Text{String: msg, Valid: true},
213 }); markErr != nil && deps.Logger != nil {
214 deps.Logger.WarnContext(ctx, "source-remote: mark fetch error", "error", markErr, "cause", err, "repo_id", repoID)
215 }
216 }
217
218 func IsInvalidSourceRemote(err error) bool {
219 return errors.Is(err, ErrInvalidSourceRemote)
220 }
221