Go · 4535 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package repo
4
5 import (
6 "context"
7 "errors"
8 "strings"
9 "time"
10
11 "github.com/jackc/pgx/v5/pgtype"
12
13 "github.com/tenseleyFlow/shithub/internal/repos"
14 repogit "github.com/tenseleyFlow/shithub/internal/repos/git"
15 reposdb "github.com/tenseleyFlow/shithub/internal/repos/sqlc"
16 "github.com/tenseleyFlow/shithub/internal/worker"
17 )
18
// sourceRemoteFetchTimeout bounds how long a single fetch from a repository's
// source remote may run; it is applied via context.WithTimeout around the
// git fetch call.
const sourceRemoteFetchTimeout time.Duration = 45 * time.Second
20
21 func (h *Handlers) saveRepoSourceRemote(ctx context.Context, repoID int64, rawURL string) (string, error) {
22 remoteURL, err := repos.ValidateSourceRemoteURL(ctx, rawURL)
23 if err != nil || remoteURL == "" {
24 return remoteURL, err
25 }
26 _, err = h.rq.UpsertRepoSourceRemote(ctx, h.d.Pool, reposdb.UpsertRepoSourceRemoteParams{
27 RepoID: repoID,
28 RemoteUrl: remoteURL,
29 })
30 return remoteURL, err
31 }
32
33 func (h *Handlers) fetchRepoSourceRemote(ctx context.Context, row reposdb.Repo, ownerSlug, remoteURL string) error {
34 remoteURL, err := repos.ValidateSourceRemoteURL(ctx, remoteURL)
35 if err != nil {
36 h.markRepoSourceRemoteFetchError(ctx, row.ID, err)
37 return err
38 }
39 gitDir, err := h.d.RepoFS.RepoPath(ownerSlug, row.Name)
40 if err != nil {
41 h.markRepoSourceRemoteFetchError(ctx, row.ID, err)
42 return err
43 }
44 fetchCtx, cancel := context.WithTimeout(ctx, sourceRemoteFetchTimeout)
45 defer cancel()
46 if err := repogit.FetchRemoteHeadsAndTags(fetchCtx, gitDir, remoteURL); err != nil {
47 h.markRepoSourceRemoteFetchError(ctx, row.ID, err)
48 return err
49 }
50 if err := h.refreshFetchedRepoState(ctx, row, gitDir); err != nil {
51 h.markRepoSourceRemoteFetchError(ctx, row.ID, err)
52 return err
53 }
54 if err := h.rq.MarkRepoSourceRemoteFetched(ctx, h.d.Pool, row.ID); err != nil && h.d.Logger != nil {
55 h.d.Logger.WarnContext(ctx, "source-remote: mark fetched", "error", err, "repo_id", row.ID)
56 }
57 return nil
58 }
59
60 func (h *Handlers) refreshFetchedRepoState(ctx context.Context, row reposdb.Repo, gitDir string) error {
61 refs, err := repogit.ListRefs(ctx, gitDir)
62 if err != nil {
63 return err
64 }
65 branch, oid := chooseFetchedDefaultBranch(row.DefaultBranch, refs.Branches)
66 if branch == "" {
67 return nil
68 }
69 if branch != row.DefaultBranch {
70 if err := h.rq.UpdateRepoDefaultBranch(ctx, h.d.Pool, reposdb.UpdateRepoDefaultBranchParams{
71 ID: row.ID,
72 DefaultBranch: branch,
73 }); err != nil {
74 return err
75 }
76 if err := repogit.SetSymbolicRef(ctx, gitDir, "HEAD", "refs/heads/"+branch); err != nil && h.d.Logger != nil {
77 h.d.Logger.WarnContext(ctx, "source-remote: set symbolic head", "error", err, "repo_id", row.ID, "branch", branch)
78 }
79 }
80 if !row.DefaultBranchOid.Valid || row.DefaultBranchOid.String != oid {
81 if err := h.rq.UpdateRepoDefaultBranchOID(ctx, h.d.Pool, reposdb.UpdateRepoDefaultBranchOIDParams{
82 ID: row.ID,
83 DefaultBranchOid: pgtype.Text{String: oid, Valid: true},
84 }); err != nil {
85 return err
86 }
87 if _, err := worker.Enqueue(ctx, h.d.Pool, worker.KindRepoIndexCode, map[string]any{"repo_id": row.ID}, worker.EnqueueOptions{}); err != nil && h.d.Logger != nil {
88 h.d.Logger.WarnContext(ctx, "source-remote: enqueue index", "error", err, "repo_id", row.ID)
89 }
90 }
91 if _, err := worker.Enqueue(ctx, h.d.Pool, worker.KindRepoSizeRecalc, map[string]any{"repo_id": row.ID}, worker.EnqueueOptions{}); err != nil && h.d.Logger != nil {
92 h.d.Logger.WarnContext(ctx, "source-remote: enqueue size", "error", err, "repo_id", row.ID)
93 }
94 _ = worker.Notify(ctx, h.d.Pool)
95 return nil
96 }
97
98 func chooseFetchedDefaultBranch(current string, branches []repogit.RefEntry) (name, oid string) {
99 if len(branches) == 0 {
100 return "", ""
101 }
102 for _, candidate := range []string{current, "trunk", "main", "master"} {
103 if candidate == "" {
104 continue
105 }
106 for _, branch := range branches {
107 if branch.Name == candidate {
108 return branch.Name, branch.OID
109 }
110 }
111 }
112 return branches[0].Name, branches[0].OID
113 }
114
115 func (h *Handlers) markRepoSourceRemoteFetchError(ctx context.Context, repoID int64, err error) {
116 if err == nil {
117 return
118 }
119 msg := strings.TrimSpace(err.Error())
120 if len(msg) > 500 {
121 msg = msg[:500]
122 }
123 if markErr := h.rq.MarkRepoSourceRemoteFetchError(ctx, h.d.Pool, reposdb.MarkRepoSourceRemoteFetchErrorParams{
124 RepoID: repoID,
125 LastError: pgtype.Text{String: msg, Valid: true},
126 }); markErr != nil && h.d.Logger != nil {
127 h.d.Logger.WarnContext(ctx, "source-remote: mark fetch error", "error", markErr, "cause", err, "repo_id", repoID)
128 }
129 }
130
131 func isInvalidSourceRemote(err error) bool {
132 return errors.Is(err, repos.ErrInvalidSourceRemote)
133 }
134