Go · 10250 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package api
4
5 import (
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "strings"
13 "time"
14
15 "github.com/go-chi/chi/v5"
16 "github.com/jackc/pgx/v5"
17
18 "github.com/tenseleyFlow/shithub/internal/actions/runnerlabels"
19 "github.com/tenseleyFlow/shithub/internal/actions/runnertoken"
20 actionsdb "github.com/tenseleyFlow/shithub/internal/actions/sqlc"
21 "github.com/tenseleyFlow/shithub/internal/auth/runnerjwt"
22 "github.com/tenseleyFlow/shithub/internal/ratelimit"
23 )
24
25 var runnerHeartbeatLimit = ratelimit.Policy{
26 Scope: "actions:runner_dispatch",
27 Max: 60,
28 Window: time.Minute,
29 }
30
31 func (h *Handlers) mountRunners(r chi.Router) {
32 r.Post("/api/v1/runners/heartbeat", h.runnerHeartbeat)
33 }
34
// runnerHeartbeatRequest is the optional JSON body of a runner heartbeat.
// The handler decodes it with DisallowUnknownFields, so extra keys are a 400.
type runnerHeartbeatRequest struct {
	// Labels, when present (non-nil), is normalized with runnerlabels.Normalize
	// and used instead of the runner's stored labels for this heartbeat.
	Labels []string `json:"labels"`
	// Capacity, when non-zero, is used instead of the runner's stored
	// capacity; the effective value must be between 1 and 64.
	Capacity int `json:"capacity"`
}
39
40 func (h *Handlers) runnerHeartbeat(w http.ResponseWriter, r *http.Request) {
41 if h.d.RunnerJWT == nil {
42 writeAPIError(w, http.StatusServiceUnavailable, "runner API is not configured")
43 return
44 }
45 runner, ok := h.authenticateRunner(w, r)
46 if !ok {
47 return
48 }
49 if !h.allowRunnerHeartbeat(w, r, runner.ID) {
50 return
51 }
52
53 var body runnerHeartbeatRequest
54 dec := json.NewDecoder(r.Body)
55 dec.DisallowUnknownFields()
56 if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) {
57 writeAPIError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
58 return
59 }
60 labels := runner.Labels
61 if body.Labels != nil {
62 var err error
63 labels, err = runnerlabels.Normalize(body.Labels)
64 if err != nil {
65 writeAPIError(w, http.StatusBadRequest, err.Error())
66 return
67 }
68 }
69 capacity := int(runner.Capacity)
70 if body.Capacity != 0 {
71 capacity = body.Capacity
72 }
73 if capacity < 1 || capacity > 64 {
74 writeAPIError(w, http.StatusBadRequest, "capacity must be between 1 and 64")
75 return
76 }
77
78 job, steps, claimed, err := h.claimRunnerJob(r.Context(), runner.ID, labels, int32(capacity))
79 if err != nil {
80 h.d.Logger.ErrorContext(r.Context(), "runner heartbeat claim failed", "runner_id", runner.ID, "error", err)
81 writeAPIError(w, http.StatusInternalServerError, "runner heartbeat failed")
82 return
83 }
84 if !claimed {
85 w.WriteHeader(http.StatusNoContent)
86 return
87 }
88
89 token, claims, err := h.d.RunnerJWT.Mint(runnerjwt.MintParams{
90 RunnerID: runner.ID,
91 JobID: job.ID,
92 RunID: job.RunID,
93 RepoID: job.RepoID,
94 })
95 if err != nil {
96 h.d.Logger.ErrorContext(r.Context(), "runner jwt mint failed", "runner_id", runner.ID, "job_id", job.ID, "error", err)
97 writeAPIError(w, http.StatusInternalServerError, "runner token mint failed")
98 return
99 }
100 writeJSON(w, http.StatusOK, presentRunnerClaim(job, steps, token, time.Unix(claims.Exp, 0)))
101 }
102
103 func (h *Handlers) authenticateRunner(w http.ResponseWriter, r *http.Request) (actionsdb.GetRunnerByTokenHashRow, bool) {
104 const prefix = "Bearer "
105 authz := r.Header.Get("Authorization")
106 if !strings.HasPrefix(authz, prefix) {
107 writeAPIError(w, http.StatusUnauthorized, "runner token required")
108 return actionsdb.GetRunnerByTokenHashRow{}, false
109 }
110 hash, err := runnertoken.HashOf(strings.TrimSpace(strings.TrimPrefix(authz, prefix)))
111 if err != nil {
112 writeAPIError(w, http.StatusUnauthorized, "runner token invalid")
113 return actionsdb.GetRunnerByTokenHashRow{}, false
114 }
115 runner, err := actionsdb.New().GetRunnerByTokenHash(r.Context(), h.d.Pool, hash)
116 if err != nil {
117 writeAPIError(w, http.StatusUnauthorized, "runner token invalid")
118 return actionsdb.GetRunnerByTokenHashRow{}, false
119 }
120 return runner, true
121 }
122
123 func (h *Handlers) allowRunnerHeartbeat(w http.ResponseWriter, r *http.Request, runnerID int64) bool {
124 if h.d.RateLimiter == nil {
125 return true
126 }
127 decision, err := h.d.RateLimiter.Allow(r.Context(), runnerHeartbeatLimit, fmt.Sprintf("runner:%d", runnerID))
128 if err != nil {
129 h.d.Logger.WarnContext(r.Context(), "runner heartbeat rate-limit failed", "runner_id", runnerID, "error", err)
130 }
131 ratelimit.StampHeaders(w, decision)
132 if !decision.Allowed {
133 w.Header().Set("Retry-After", fmt.Sprintf("%d", int(decision.RetryAfter/time.Second)))
134 writeAPIError(w, http.StatusTooManyRequests, "rate limit exceeded")
135 return false
136 }
137 return true
138 }
139
140 func (h *Handlers) claimRunnerJob(
141 ctx context.Context,
142 runnerID int64,
143 labels []string,
144 capacity int32,
145 ) (actionsdb.ClaimQueuedWorkflowJobRow, []actionsdb.ListRunnerStepsForJobRow, bool, error) {
146 q := actionsdb.New()
147 tx, err := h.d.Pool.Begin(ctx)
148 if err != nil {
149 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
150 }
151 committed := false
152 defer func() {
153 if !committed {
154 _ = tx.Rollback(ctx)
155 }
156 }()
157
158 if _, err := q.LockRunnerByID(ctx, tx, runnerID); err != nil {
159 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
160 }
161 running, err := q.CountRunningJobsForRunner(ctx, tx, runnerID)
162 if err != nil {
163 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
164 }
165 if running >= capacity {
166 if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
167 ID: runnerID,
168 Labels: labels,
169 Capacity: capacity,
170 Status: actionsdb.WorkflowRunnerStatusBusy,
171 }); err != nil {
172 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
173 }
174 if err := tx.Commit(ctx); err != nil {
175 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
176 }
177 committed = true
178 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, nil
179 }
180
181 job, err := q.ClaimQueuedWorkflowJob(ctx, tx, actionsdb.ClaimQueuedWorkflowJobParams{
182 RunnerID: runnerID,
183 Labels: labels,
184 })
185 if err != nil {
186 if !errors.Is(err, pgx.ErrNoRows) {
187 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
188 }
189 if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
190 ID: runnerID,
191 Labels: labels,
192 Capacity: capacity,
193 Status: actionsdb.WorkflowRunnerStatusIdle,
194 }); err != nil {
195 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
196 }
197 if err := tx.Commit(ctx); err != nil {
198 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
199 }
200 committed = true
201 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, nil
202 }
203 if err := q.MarkWorkflowRunRunning(ctx, tx, job.RunID); err != nil {
204 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
205 }
206 steps, err := q.ListRunnerStepsForJob(ctx, tx, job.ID)
207 if err != nil {
208 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
209 }
210 status := actionsdb.WorkflowRunnerStatusIdle
211 if running+1 >= capacity {
212 status = actionsdb.WorkflowRunnerStatusBusy
213 }
214 if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
215 ID: runnerID,
216 Labels: labels,
217 Capacity: capacity,
218 Status: status,
219 }); err != nil {
220 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
221 }
222 if err := tx.Commit(ctx); err != nil {
223 return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, false, err
224 }
225 committed = true
226 return job, steps, true, nil
227 }
228
// runnerClaimResponse is the 200 response body sent to a runner that claimed
// a job during its heartbeat. Field order determines JSON key order.
type runnerClaimResponse struct {
	// Token is the runner JWT minted for the claimed job.
	Token string `json:"token"`
	// ExpiresAt is the token's expiry, formatted as RFC 3339 in UTC.
	ExpiresAt string `json:"expires_at"`
	// Job describes the claimed job, including its steps.
	Job runnerJobPayload `json:"job"`
}
234
// runnerJobPayload is the wire representation of a claimed workflow job,
// built by presentRunnerClaim from a ClaimQueuedWorkflowJobRow. Field order
// determines JSON key order, so fields must not be reordered casually.
type runnerJobPayload struct {
	// ID, RunID and RepoID identify the job, its workflow run and repository.
	ID     int64 `json:"id"`
	RunID  int64 `json:"run_id"`
	RepoID int64 `json:"repo_id"`
	// RunIndex is copied from the job row's RunIndex column.
	RunIndex     int64  `json:"run_index"`
	WorkflowFile string `json:"workflow_file"`
	WorkflowName string `json:"workflow_name"`
	HeadSHA      string `json:"head_sha"`
	HeadRef      string `json:"head_ref"`
	// Event is the triggering event, converted to its string form.
	Event   string `json:"event"`
	JobKey  string `json:"job_key"`
	JobName string `json:"job_name"`
	RunsOn  string `json:"runs_on"`
	// Needs lists the job keys this job depends on (from NeedsJobs).
	Needs []string `json:"needs"`
	// If is the job-level condition expression, passed through verbatim.
	If             string `json:"if"`
	TimeoutMinutes int32  `json:"timeout_minutes"`
	// Permissions and Env are passed through as raw JSON; presentRunnerClaim
	// substitutes {} when the stored value is empty or not valid JSON.
	Permissions json.RawMessage `json:"permissions"`
	Env         json.RawMessage `json:"env"`
	Steps       []runnerStep    `json:"steps"`
}
255
// runnerStep is the wire representation of one step of a claimed job, built
// by presentRunnerClaim from a ListRunnerStepsForJobRow. Field order
// determines JSON key order.
type runnerStep struct {
	ID int64 `json:"id"`
	// Index is the step's position (from StepIndex).
	Index  int32  `json:"index"`
	StepID string `json:"step_id"`
	Name   string `json:"name"`
	// If is the step-level condition expression, passed through verbatim.
	If string `json:"if"`
	// Run holds the step's command (from RunCommand).
	Run string `json:"run"`
	// Uses holds the step's UsesAlias value.
	Uses             string `json:"uses"`
	WorkingDirectory string `json:"working_directory"`
	// Env and With are raw JSON; presentRunnerClaim substitutes {} when the
	// stored value is empty or not valid JSON.
	Env             json.RawMessage `json:"env"`
	With            json.RawMessage `json:"with"`
	ContinueOnError bool            `json:"continue_on_error"`
}
269
270 func presentRunnerClaim(
271 job actionsdb.ClaimQueuedWorkflowJobRow,
272 steps []actionsdb.ListRunnerStepsForJobRow,
273 token string,
274 expiresAt time.Time,
275 ) runnerClaimResponse {
276 outSteps := make([]runnerStep, 0, len(steps))
277 for _, step := range steps {
278 outSteps = append(outSteps, runnerStep{
279 ID: step.ID,
280 Index: step.StepIndex,
281 StepID: step.StepID,
282 Name: step.StepName,
283 If: step.IfExpr,
284 Run: step.RunCommand,
285 Uses: step.UsesAlias,
286 WorkingDirectory: step.WorkingDirectory,
287 Env: rawJSONOrObject(step.StepEnv),
288 With: rawJSONOrObject(step.StepWith),
289 ContinueOnError: step.ContinueOnError,
290 })
291 }
292 return runnerClaimResponse{
293 Token: token,
294 ExpiresAt: expiresAt.UTC().Format(time.RFC3339),
295 Job: runnerJobPayload{
296 ID: job.ID,
297 RunID: job.RunID,
298 RepoID: job.RepoID,
299 RunIndex: job.RunIndex,
300 WorkflowFile: job.WorkflowFile,
301 WorkflowName: job.WorkflowName,
302 HeadSHA: job.HeadSha,
303 HeadRef: job.HeadRef,
304 Event: string(job.Event),
305 JobKey: job.JobKey,
306 JobName: job.JobName,
307 RunsOn: job.RunsOn,
308 Needs: job.NeedsJobs,
309 If: job.IfExpr,
310 TimeoutMinutes: job.TimeoutMinutes,
311 Permissions: rawJSONOrObject(job.Permissions),
312 Env: rawJSONOrObject(job.JobEnv),
313 Steps: outSteps,
314 },
315 }
316 }
317
// rawJSONOrObject passes b through as a json.RawMessage when it is a
// well-formed JSON document. For nil, empty or malformed input it returns the
// empty object `{}`, so the encoded response always carries valid JSON.
//
// It does not require the document to be an object: valid arrays, strings or
// `null` are returned unchanged. The result shares b's backing array.
func rawJSONOrObject(b []byte) json.RawMessage {
	if len(b) > 0 && json.Valid(b) {
		return json.RawMessage(b)
	}
	return json.RawMessage(`{}`)
}
324