Go · 12735 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package runner orchestrates the shithubd-runner claim/execute/status loop.
4 package runner
5
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"maps"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tenseleyFlow/shithub/internal/runner/api"
	"github.com/tenseleyFlow/shithub/internal/runner/engine"
)
19
// API is the control-plane client surface the runner depends on. All
// job-scoped calls take the job's current auth token and may return a
// rotated token in their response (see jobSession).
type API interface {
	// Heartbeat announces liveness/capacity and doubles as the claim
	// request; a nil *Claim means no work was assigned.
	Heartbeat(ctx context.Context, req api.HeartbeatRequest) (*api.Claim, error)
	// UpdateStatus reports job-level status transitions (running/completed/cancelled).
	UpdateStatus(ctx context.Context, jobID int64, token string, req api.StatusRequest) (api.StatusResponse, error)
	// UpdateStepStatus reports per-step status for a job.
	UpdateStepStatus(ctx context.Context, jobID, stepID int64, token string, req api.StatusRequest) (api.StepStatusResponse, error)
	// AppendLog uploads one sequenced log chunk for a job (optionally step-scoped).
	AppendLog(ctx context.Context, jobID int64, token string, req api.LogRequest) (api.LogResponse, error)
	// CancelCheck asks the server whether the job has been cancelled.
	CancelCheck(ctx context.Context, jobID int64, token string) (api.CancelCheckResponse, error)
}
27
// Workspaces provisions and tears down the per-job working directory.
type Workspaces interface {
	// Prepare creates (or resets) the workspace for the given run/job
	// and returns its absolute path.
	Prepare(runID, jobID int64) (string, error)
	// Remove deletes the workspace for the given run/job.
	Remove(runID, jobID int64) error
}
32
// SleepFunc blocks for d or until ctx is done, returning ctx's error in
// the latter case. Injected so tests can control time (see defaultSleep).
type SleepFunc func(ctx context.Context, d time.Duration) error
34
// Options configures a Runner. API, Engine and Workspaces are required;
// the remaining fields fall back to defaults in New when zero.
type Options struct {
	API        API
	Engine     engine.Engine
	Workspaces Workspaces
	Logger     *slog.Logger // defaults to a discard logger
	Labels     []string     // capability labels advertised in heartbeats
	Capacity   int          // advertised slot count; defaults to 1
	HostName   string
	Version    string
	PollInterval       time.Duration // idle heartbeat interval; defaults to 5s
	CancelPollInterval time.Duration // cancellation poll interval; defaults to 2s
	DefaultImage       string        // container image used when a job has none
	Clock              func() time.Time // defaults to time.Now().UTC
	Sleep              SleepFunc        // defaults to defaultSleep
}
50
// Runner drives the claim/execute/report loop against the control plane.
// All fields are set once by New and read-only afterwards.
type Runner struct {
	api                API
	engine             engine.Engine
	workspaces         Workspaces
	logger             *slog.Logger
	labels             []string
	capacity           int
	hostName           string
	version            string
	pollInterval       time.Duration
	cancelPollInterval time.Duration
	defaultImage       string
	clock              func() time.Time
	sleep              SleepFunc
}
66
67 func New(opts Options) *Runner {
68 logger := opts.Logger
69 if logger == nil {
70 logger = slog.New(slog.NewTextHandler(io.Discard, nil))
71 }
72 clock := opts.Clock
73 if clock == nil {
74 clock = func() time.Time { return time.Now().UTC() }
75 }
76 sleep := opts.Sleep
77 if sleep == nil {
78 sleep = defaultSleep
79 }
80 poll := opts.PollInterval
81 if poll <= 0 {
82 poll = 5 * time.Second
83 }
84 cancelPoll := opts.CancelPollInterval
85 if cancelPoll <= 0 {
86 cancelPoll = 2 * time.Second
87 }
88 capacity := opts.Capacity
89 if capacity <= 0 {
90 capacity = 1
91 }
92 return &Runner{
93 api: opts.API,
94 engine: opts.Engine,
95 workspaces: opts.Workspaces,
96 logger: logger,
97 labels: append([]string{}, opts.Labels...),
98 capacity: capacity,
99 hostName: opts.HostName,
100 version: opts.Version,
101 pollInterval: poll,
102 cancelPollInterval: cancelPoll,
103 defaultImage: opts.DefaultImage,
104 clock: clock,
105 sleep: sleep,
106 }
107 }
108
109 func (r *Runner) Run(ctx context.Context) error {
110 for {
111 claimed, err := r.RunOnce(ctx)
112 if err != nil {
113 if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
114 return err
115 }
116 r.logger.ErrorContext(ctx, "runner loop iteration failed", "error", err)
117 }
118 if claimed {
119 continue
120 }
121 if err := r.sleep(ctx, r.pollInterval); err != nil {
122 return err
123 }
124 }
125 }
126
// RunOnce performs one heartbeat/claim cycle. It returns (false, nil)
// when the server assigned no job; once a job is claimed it always
// returns claimed=true, plus any error from executing and reporting it.
// The sequence is: prepare workspace -> mark running -> start log/event
// drain and cancel-watch goroutines -> execute -> join goroutines ->
// report step + job completion.
func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
	// The heartbeat doubles as the claim request; nil claim means idle.
	claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{
		Labels:   r.labels,
		Capacity: r.capacity,
		HostName: r.hostName,
		Version:  r.version,
	})
	if err != nil {
		return false, err
	}
	if claim == nil {
		return false, nil
	}
	// All further API calls go through the session so token rotation is
	// serialized across the drain/cancel goroutines.
	session := newJobSession(r.api, claim.Job.ID, claim.Token)
	started := r.clock()
	workspaceDir, err := r.workspaces.Prepare(claim.Job.RunID, claim.Job.ID)
	if err != nil {
		// Best-effort: still report the job as failed before bailing.
		statusErr := r.complete(ctx, session, engine.ConclusionFailure, started, r.clock())
		return true, errors.Join(fmt.Errorf("prepare workspace: %w", err), statusErr)
	}
	defer func() {
		if err := r.workspaces.Remove(claim.Job.RunID, claim.Job.ID); err != nil {
			r.logger.WarnContext(ctx, "workspace cleanup failed", "run_id", claim.Job.RunID, "job_id", claim.Job.ID, "error", err)
		}
	}()

	running, err := session.UpdateStatus(ctx, api.StatusRequest{
		Status:    "running",
		StartedAt: started,
	})
	if err != nil {
		return true, fmt.Errorf("mark job running: %w", err)
	}
	if running.NextToken == "" {
		return true, errors.New("mark job running: server did not return next_token")
	}
	var (
		streamedEvents bool
		drainErr       chan error
	)
	// Prefer the richer event stream (logs + step transitions) when the
	// engine supports it; otherwise fall back to a raw log stream and
	// report step outcomes after execution instead.
	if streamer, ok := r.engine.(engine.EventStreamer); ok {
		events, err := streamer.StreamEvents(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open event stream: %w", err)
		}
		streamedEvents = true
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainEvents(ctx, session, events)
		}()
	} else {
		logs, err := r.engine.StreamLogs(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open log stream: %w", err)
		}
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainLogs(ctx, session, logs)
		}()
	}

	// execCtx is cancelled by the watcher on server-side cancellation;
	// watchCtx stops the watcher itself once Execute returns.
	execCtx, execCancel := context.WithCancel(ctx)
	watchCtx, stopCancelWatch := context.WithCancel(ctx)
	cancelRequested := atomic.Bool{}
	cancelWatchErr := make(chan error, 1)
	go func() {
		cancelWatchErr <- r.watchCancel(watchCtx, session, claim.Job.ID, execCancel, &cancelRequested)
	}()

	outcome, execErr := r.engine.Execute(execCtx, toEngineJob(claim.Job, workspaceDir, r.defaultImage))
	stopCancelWatch()
	// Join both goroutines before reporting completion so no status/log
	// call races the final UpdateStatus (token rotation ordering).
	if err := <-cancelWatchErr; err != nil {
		return true, fmt.Errorf("watch job cancellation: %w", err)
	}
	if err := <-drainErr; err != nil {
		return true, fmt.Errorf("stream runner events: %w", err)
	}
	if !streamedEvents {
		// Log-only engines report step outcomes in one batch after the fact.
		for _, step := range outcome.StepOutcomes {
			if step.StepID == 0 {
				continue
			}
			if err := session.UpdateStepStatus(ctx, step.StepID, api.StatusRequest{
				Status:      step.Status,
				Conclusion:  step.Conclusion,
				StartedAt:   step.StartedAt,
				CompletedAt: step.CompletedAt,
			}); err != nil {
				return true, fmt.Errorf("mark step completed: %w", err)
			}
		}
	}
	conclusion := outcome.Conclusion
	if conclusion == "" {
		// An engine that reports nothing is treated as a failure.
		conclusion = engine.ConclusionFailure
	}
	finalStatus := "completed"
	if cancelRequested.Load() {
		// Server-requested cancellation overrides the engine's verdict.
		finalStatus = "cancelled"
		conclusion = engine.ConclusionCancelled
	}
	completed := outcome.CompletedAt
	if completed.IsZero() {
		completed = r.clock()
	}
	if outcome.StartedAt.IsZero() {
		outcome.StartedAt = started
	}
	if err := r.finish(ctx, session, finalStatus, conclusion, outcome.StartedAt, completed); err != nil {
		return true, err
	}
	// Engine failures surface via the reported conclusion, not as a
	// RunOnce error — the claim cycle itself succeeded.
	if execErr != nil && !cancelRequested.Load() {
		r.logger.WarnContext(ctx, "job completed with failing engine outcome", "job_id", claim.Job.ID, "conclusion", conclusion, "error", execErr)
	}
	return true, nil
}
243
244 func (r *Runner) complete(ctx context.Context, session *jobSession, conclusion string, started, completed time.Time) error {
245 return r.finish(ctx, session, "completed", conclusion, started, completed)
246 }
247
248 func (r *Runner) finish(ctx context.Context, session *jobSession, status, conclusion string, started, completed time.Time) error {
249 _, err := session.UpdateStatus(ctx, api.StatusRequest{
250 Status: status,
251 Conclusion: conclusion,
252 StartedAt: started,
253 CompletedAt: completed,
254 })
255 if err != nil {
256 return fmt.Errorf("mark job %s: %w", status, err)
257 }
258 return nil
259 }
260
261 func (r *Runner) watchCancel(
262 ctx context.Context,
263 session *jobSession,
264 jobID int64,
265 execCancel context.CancelFunc,
266 cancelRequested *atomic.Bool,
267 ) error {
268 for {
269 if err := r.sleep(ctx, r.cancelPollInterval); err != nil {
270 if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
271 return nil
272 }
273 return err
274 }
275 resp, err := session.CancelCheck(ctx)
276 if err != nil {
277 if ctx.Err() != nil {
278 return nil
279 }
280 r.logger.WarnContext(ctx, "runner cancel check failed", "job_id", jobID, "error", err)
281 continue
282 }
283 if !resp.Cancelled {
284 continue
285 }
286 cancelRequested.Store(true)
287 killCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
288 if err := r.engine.Cancel(killCtx, jobID); err != nil && !errors.Is(err, engine.ErrUnsupported) {
289 r.logger.WarnContext(ctx, "runner engine cancel failed", "job_id", jobID, "error", err)
290 }
291 cancel()
292 execCancel()
293 return nil
294 }
295 }
296
// jobSession serializes all API calls for one claimed job and tracks its
// rotating auth token: any response carrying a NextToken replaces the
// stored token for subsequent calls. mu guards token across the main
// loop, the log/event drain goroutine, and the cancel watcher.
type jobSession struct {
	api   API
	jobID int64
	token string
	mu    sync.Mutex
}
303
304 func newJobSession(api API, jobID int64, token string) *jobSession {
305 return &jobSession{api: api, jobID: jobID, token: token}
306 }
307
308 func (s *jobSession) UpdateStatus(ctx context.Context, req api.StatusRequest) (api.StatusResponse, error) {
309 s.mu.Lock()
310 defer s.mu.Unlock()
311 resp, err := s.api.UpdateStatus(ctx, s.jobID, s.token, req)
312 if err != nil {
313 return resp, err
314 }
315 if resp.NextToken != "" {
316 s.token = resp.NextToken
317 }
318 return resp, nil
319 }
320
321 func (s *jobSession) UpdateStepStatus(ctx context.Context, stepID int64, req api.StatusRequest) error {
322 s.mu.Lock()
323 defer s.mu.Unlock()
324 resp, err := s.api.UpdateStepStatus(ctx, s.jobID, stepID, s.token, req)
325 if err != nil {
326 return err
327 }
328 if resp.NextToken != "" {
329 s.token = resp.NextToken
330 }
331 return nil
332 }
333
334 func (s *jobSession) AppendLog(ctx context.Context, chunk engine.LogChunk) error {
335 if len(chunk.Chunk) == 0 {
336 return nil
337 }
338 s.mu.Lock()
339 defer s.mu.Unlock()
340 resp, err := s.api.AppendLog(ctx, s.jobID, s.token, api.LogRequest{
341 Seq: chunk.Seq,
342 Chunk: chunk.Chunk,
343 StepID: chunk.StepID,
344 })
345 if err != nil {
346 return err
347 }
348 if resp.NextToken != "" {
349 s.token = resp.NextToken
350 }
351 return nil
352 }
353
354 func (s *jobSession) CancelCheck(ctx context.Context) (api.CancelCheckResponse, error) {
355 s.mu.Lock()
356 defer s.mu.Unlock()
357 resp, err := s.api.CancelCheck(ctx, s.jobID, s.token)
358 if err != nil {
359 return resp, err
360 }
361 if resp.NextToken != "" {
362 s.token = resp.NextToken
363 }
364 return resp, nil
365 }
366
367 func drainLogs(ctx context.Context, session *jobSession, logs <-chan engine.LogChunk) error {
368 for {
369 select {
370 case <-ctx.Done():
371 return ctx.Err()
372 case chunk, ok := <-logs:
373 if !ok {
374 return nil
375 }
376 if err := session.AppendLog(ctx, chunk); err != nil {
377 return err
378 }
379 }
380 }
381 }
382
383 func drainEvents(ctx context.Context, session *jobSession, events <-chan engine.Event) error {
384 for {
385 select {
386 case <-ctx.Done():
387 return ctx.Err()
388 case event, ok := <-events:
389 if !ok {
390 return nil
391 }
392 if event.Log != nil {
393 if err := session.AppendLog(ctx, *event.Log); err != nil {
394 return err
395 }
396 }
397 if event.Step != nil && event.Step.StepID != 0 {
398 if err := session.UpdateStepStatus(ctx, event.Step.StepID, api.StatusRequest{
399 Status: event.Step.Status,
400 Conclusion: event.Step.Conclusion,
401 StartedAt: event.Step.StartedAt,
402 CompletedAt: event.Step.CompletedAt,
403 }); err != nil {
404 return err
405 }
406 }
407 }
408 }
409 }
410
411 func toEngineJob(job api.Job, workspaceDir, defaultImage string) engine.Job {
412 steps := make([]engine.Step, 0, len(job.Steps))
413 for _, step := range job.Steps {
414 steps = append(steps, engine.Step{
415 ID: step.ID,
416 Index: step.Index,
417 StepID: step.StepID,
418 Name: step.Name,
419 If: step.If,
420 Run: step.Run,
421 Uses: step.Uses,
422 WorkingDirectory: step.WorkingDirectory,
423 Env: step.Env,
424 With: step.With,
425 ContinueOnError: step.ContinueOnError,
426 })
427 }
428 return engine.Job{
429 ID: job.ID,
430 RunID: job.RunID,
431 RepoID: job.RepoID,
432 RunIndex: job.RunIndex,
433 WorkflowFile: job.WorkflowFile,
434 WorkflowName: job.WorkflowName,
435 CheckoutURL: job.CheckoutURL,
436 CheckoutToken: job.CheckoutToken,
437 HeadSHA: job.HeadSHA,
438 HeadRef: job.HeadRef,
439 Event: job.Event,
440 EventPayload: job.EventPayload,
441 JobKey: job.JobKey,
442 JobName: job.JobName,
443 RunsOn: job.RunsOn,
444 Needs: append([]string{}, job.Needs...),
445 If: job.If,
446 TimeoutMinutes: job.TimeoutMinutes,
447 Permissions: job.Permissions,
448 Secrets: cloneStringMap(job.Secrets),
449 Env: job.Env,
450 Steps: steps,
451 WorkspaceDir: workspaceDir,
452 Image: defaultImage,
453 MaskValues: append([]string{}, job.MaskValues...),
454 }
455 }
456
// cloneStringMap returns an independent copy of in, or nil when in is
// nil or empty (callers rely on the nil-for-empty normalization). The
// hand-rolled copy loop is replaced by maps.Clone (Go 1.21+), which this
// file's toolchain already requires via log/slog.
func cloneStringMap(in map[string]string) map[string]string {
	if len(in) == 0 {
		return nil
	}
	return maps.Clone(in)
}
467
468 func defaultSleep(ctx context.Context, d time.Duration) error {
469 timer := time.NewTimer(d)
470 defer timer.Stop()
471 select {
472 case <-ctx.Done():
473 return ctx.Err()
474 case <-timer.C:
475 return nil
476 }
477 }
478