Go · 12425 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package runner orchestrates the shithubd-runner claim/execute/status loop.
4 package runner
5
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"maps"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tenseleyFlow/shithub/internal/runner/api"
	"github.com/tenseleyFlow/shithub/internal/runner/engine"
)
19
// API is the subset of the control-plane client the runner depends on:
// claiming work via heartbeats and reporting job status, step status, log
// chunks, and cancellation checks. All per-job calls carry a job-scoped
// token; responses may rotate it via NextToken (handled by jobSession).
type API interface {
	Heartbeat(ctx context.Context, req api.HeartbeatRequest) (*api.Claim, error)
	UpdateStatus(ctx context.Context, jobID int64, token string, req api.StatusRequest) (api.StatusResponse, error)
	UpdateStepStatus(ctx context.Context, jobID, stepID int64, token string, req api.StatusRequest) (api.StepStatusResponse, error)
	AppendLog(ctx context.Context, jobID int64, token string, req api.LogRequest) (api.LogResponse, error)
	CancelCheck(ctx context.Context, jobID int64, token string) (api.CancelCheckResponse, error)
}
27
// Workspaces provisions and tears down the on-disk directory a job executes
// in. Prepare returns the absolute path handed to the engine; Remove is
// called (best-effort) once the job has finished.
type Workspaces interface {
	Prepare(runID, jobID int64) (string, error)
	Remove(runID, jobID int64) error
}
32
// SleepFunc waits for d or until ctx is done, returning ctx.Err() when
// interrupted (see defaultSleep). Pluggable via Options.Sleep so waiting
// can be replaced, e.g. in tests.
type SleepFunc func(ctx context.Context, d time.Duration) error
34
// Options configures a Runner. API, Engine, and Workspaces are required
// collaborators; the remaining fields are optional and defaulted by New.
type Options struct {
	API        API
	Engine     engine.Engine
	Workspaces Workspaces
	// Logger defaults to a discard logger when nil.
	Logger *slog.Logger
	// Labels are advertised in every heartbeat so the server can match jobs.
	Labels []string
	// Capacity advertised per heartbeat; defaults to 1 when <= 0.
	Capacity int
	// PollInterval is the idle wait between heartbeats; defaults to 5s.
	PollInterval time.Duration
	// CancelPollInterval is the wait between cancellation checks while a job
	// runs; defaults to 2s.
	CancelPollInterval time.Duration
	// DefaultImage is passed to the engine as the job image.
	DefaultImage string
	// Clock defaults to time.Now in UTC.
	Clock func() time.Time
	// Sleep defaults to defaultSleep.
	Sleep SleepFunc
}
48
// Runner owns the claim/execute/report loop. All fields are fixed at
// construction by New (with defaults applied) and read-only afterwards.
type Runner struct {
	api                API
	engine             engine.Engine
	workspaces         Workspaces
	logger             *slog.Logger
	labels             []string
	capacity           int
	pollInterval       time.Duration
	cancelPollInterval time.Duration
	defaultImage       string
	clock              func() time.Time
	sleep              SleepFunc
}
62
63 func New(opts Options) *Runner {
64 logger := opts.Logger
65 if logger == nil {
66 logger = slog.New(slog.NewTextHandler(io.Discard, nil))
67 }
68 clock := opts.Clock
69 if clock == nil {
70 clock = func() time.Time { return time.Now().UTC() }
71 }
72 sleep := opts.Sleep
73 if sleep == nil {
74 sleep = defaultSleep
75 }
76 poll := opts.PollInterval
77 if poll <= 0 {
78 poll = 5 * time.Second
79 }
80 cancelPoll := opts.CancelPollInterval
81 if cancelPoll <= 0 {
82 cancelPoll = 2 * time.Second
83 }
84 capacity := opts.Capacity
85 if capacity <= 0 {
86 capacity = 1
87 }
88 return &Runner{
89 api: opts.API,
90 engine: opts.Engine,
91 workspaces: opts.Workspaces,
92 logger: logger,
93 labels: append([]string{}, opts.Labels...),
94 capacity: capacity,
95 pollInterval: poll,
96 cancelPollInterval: cancelPoll,
97 defaultImage: opts.DefaultImage,
98 clock: clock,
99 sleep: sleep,
100 }
101 }
102
103 func (r *Runner) Run(ctx context.Context) error {
104 for {
105 claimed, err := r.RunOnce(ctx)
106 if err != nil {
107 if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
108 return err
109 }
110 r.logger.ErrorContext(ctx, "runner loop iteration failed", "error", err)
111 }
112 if claimed {
113 continue
114 }
115 if err := r.sleep(ctx, r.pollInterval); err != nil {
116 return err
117 }
118 }
119 }
120
// RunOnce performs a single heartbeat and, when a job is claimed, runs it to
// completion. It returns (false, err) when nothing was claimed, and
// (true, err) once a claim was obtained — even if the job then failed — so
// Run can skip the idle sleep after busy iterations.
//
// For a claimed job it: prepares a workspace, marks the job running,
// streams engine output (live step events when the engine implements
// engine.EventStreamer, raw log chunks otherwise), watches for server-side
// cancellation concurrently with execution, and finally reports the
// terminal status and conclusion.
func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
	claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{Labels: r.labels, Capacity: r.capacity})
	if err != nil {
		return false, err
	}
	if claim == nil {
		// Heartbeat succeeded but there is no work; caller polls again later.
		return false, nil
	}
	session := newJobSession(r.api, claim.Job.ID, claim.Token)
	started := r.clock()
	workspaceDir, err := r.workspaces.Prepare(claim.Job.RunID, claim.Job.ID)
	if err != nil {
		// Report the failure before bailing; both errors are surfaced.
		// NOTE(review): the job was never marked "running" on this path —
		// confirm the server accepts a direct completed/failure transition.
		statusErr := r.complete(ctx, session, engine.ConclusionFailure, started, r.clock())
		return true, errors.Join(fmt.Errorf("prepare workspace: %w", err), statusErr)
	}
	defer func() {
		// Cleanup failures are logged, not returned: by the time this runs
		// the job outcome has already been decided.
		if err := r.workspaces.Remove(claim.Job.RunID, claim.Job.ID); err != nil {
			r.logger.WarnContext(ctx, "workspace cleanup failed", "run_id", claim.Job.RunID, "job_id", claim.Job.ID, "error", err)
		}
	}()

	running, err := session.UpdateStatus(ctx, api.StatusRequest{
		Status:    "running",
		StartedAt: started,
	})
	if err != nil {
		return true, fmt.Errorf("mark job running: %w", err)
	}
	// The rotating-token scheme requires a fresh token here; without one,
	// later status/log calls would run on a stale credential.
	if running.NextToken == "" {
		return true, errors.New("mark job running: server did not return next_token")
	}
	var (
		streamedEvents bool       // true when steps are reported live via the event stream
		drainErr       chan error // result of the drain goroutine (buffered, capacity 1)
	)
	if streamer, ok := r.engine.(engine.EventStreamer); ok {
		events, err := streamer.StreamEvents(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open event stream: %w", err)
		}
		streamedEvents = true
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainEvents(ctx, session, events)
		}()
	} else {
		// Fallback for engines without event support: stream logs only;
		// step statuses are replayed from the outcome after execution.
		logs, err := r.engine.StreamLogs(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open log stream: %w", err)
		}
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainLogs(ctx, session, logs)
		}()
	}

	// execCtx is cancelled by the watcher when the server requests
	// cancellation; watchCtx stops the watcher once execution finishes.
	execCtx, execCancel := context.WithCancel(ctx)
	watchCtx, stopCancelWatch := context.WithCancel(ctx)
	cancelRequested := atomic.Bool{}
	cancelWatchErr := make(chan error, 1)
	go func() {
		cancelWatchErr <- r.watchCancel(watchCtx, session, claim.Job.ID, execCancel, &cancelRequested)
	}()

	outcome, execErr := r.engine.Execute(execCtx, toEngineJob(claim.Job, workspaceDir, r.defaultImage))
	stopCancelWatch()
	// Join both helper goroutines before the terminal status update so all
	// logs and step updates land ahead of the final transition.
	if err := <-cancelWatchErr; err != nil {
		return true, fmt.Errorf("watch job cancellation: %w", err)
	}
	if err := <-drainErr; err != nil {
		return true, fmt.Errorf("stream runner events: %w", err)
	}
	if !streamedEvents {
		// No live events: replay per-step outcomes now. A zero StepID means
		// the step has no server-side identity to update.
		for _, step := range outcome.StepOutcomes {
			if step.StepID == 0 {
				continue
			}
			if err := session.UpdateStepStatus(ctx, step.StepID, api.StatusRequest{
				Status:      step.Status,
				Conclusion:  step.Conclusion,
				StartedAt:   step.StartedAt,
				CompletedAt: step.CompletedAt,
			}); err != nil {
				return true, fmt.Errorf("mark step completed: %w", err)
			}
		}
	}
	conclusion := outcome.Conclusion
	if conclusion == "" {
		// An engine that reports no conclusion is treated as a failure.
		conclusion = engine.ConclusionFailure
	}
	finalStatus := "completed"
	if cancelRequested.Load() {
		// A server-side cancel overrides whatever the engine concluded.
		finalStatus = "cancelled"
		conclusion = engine.ConclusionCancelled
	}
	completed := outcome.CompletedAt
	if completed.IsZero() {
		completed = r.clock()
	}
	if outcome.StartedAt.IsZero() {
		outcome.StartedAt = started
	}
	if err := r.finish(ctx, session, finalStatus, conclusion, outcome.StartedAt, completed); err != nil {
		return true, err
	}
	if execErr != nil && !cancelRequested.Load() {
		// The job was reported using the engine's conclusion; the raw engine
		// error is only logged so a flaky engine does not fail the loop.
		r.logger.WarnContext(ctx, "job completed with failing engine outcome", "job_id", claim.Job.ID, "conclusion", conclusion, "error", execErr)
	}
	return true, nil
}
232
233 func (r *Runner) complete(ctx context.Context, session *jobSession, conclusion string, started, completed time.Time) error {
234 return r.finish(ctx, session, "completed", conclusion, started, completed)
235 }
236
237 func (r *Runner) finish(ctx context.Context, session *jobSession, status, conclusion string, started, completed time.Time) error {
238 _, err := session.UpdateStatus(ctx, api.StatusRequest{
239 Status: status,
240 Conclusion: conclusion,
241 StartedAt: started,
242 CompletedAt: completed,
243 })
244 if err != nil {
245 return fmt.Errorf("mark job %s: %w", status, err)
246 }
247 return nil
248 }
249
// watchCancel polls the server for a cancellation request every
// cancelPollInterval until ctx (the watch context) is cancelled. On the
// first positive check it records the request in cancelRequested, asks the
// engine to kill the job (best-effort, 5s budget), cancels the execution
// context, and returns. It returns nil on normal shutdown or after handling
// a cancel; a non-nil error only if the injected sleep fails for a
// non-context reason.
func (r *Runner) watchCancel(
	ctx context.Context,
	session *jobSession,
	jobID int64,
	execCancel context.CancelFunc,
	cancelRequested *atomic.Bool,
) error {
	for {
		// Sleep first so the check never fires immediately after the claim.
		if err := r.sleep(ctx, r.cancelPollInterval); err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				// Normal shutdown: the job finished and RunOnce stopped the watch.
				return nil
			}
			return err
		}
		resp, err := session.CancelCheck(ctx)
		if err != nil {
			if ctx.Err() != nil {
				// The watch was stopped while the request was in flight.
				return nil
			}
			// Transient check failures are logged and retried next tick.
			r.logger.WarnContext(ctx, "runner cancel check failed", "job_id", jobID, "error", err)
			continue
		}
		if !resp.Cancelled {
			continue
		}
		// Record the request before killing anything so RunOnce reports
		// "cancelled" rather than a plain failure.
		cancelRequested.Store(true)
		killCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		// Engine-level kill is best-effort; ErrUnsupported engines rely on
		// the execCancel below instead.
		if err := r.engine.Cancel(killCtx, jobID); err != nil && !errors.Is(err, engine.ErrUnsupported) {
			r.logger.WarnContext(ctx, "runner engine cancel failed", "job_id", jobID, "error", err)
		}
		cancel()
		execCancel()
		return nil
	}
}
285
// jobSession serializes all API calls for one claimed job and tracks its
// rotating auth token: whenever a response carries a NextToken, subsequent
// calls use it. The mutex keeps the call+token-update pairs atomic across
// the concurrent drain and cancel-watch goroutines.
type jobSession struct {
	api   API
	jobID int64
	token string // current token; replaced whenever a response rotates it
	mu    sync.Mutex
}
292
293 func newJobSession(api API, jobID int64, token string) *jobSession {
294 return &jobSession{api: api, jobID: jobID, token: token}
295 }
296
297 func (s *jobSession) UpdateStatus(ctx context.Context, req api.StatusRequest) (api.StatusResponse, error) {
298 s.mu.Lock()
299 defer s.mu.Unlock()
300 resp, err := s.api.UpdateStatus(ctx, s.jobID, s.token, req)
301 if err != nil {
302 return resp, err
303 }
304 if resp.NextToken != "" {
305 s.token = resp.NextToken
306 }
307 return resp, nil
308 }
309
310 func (s *jobSession) UpdateStepStatus(ctx context.Context, stepID int64, req api.StatusRequest) error {
311 s.mu.Lock()
312 defer s.mu.Unlock()
313 resp, err := s.api.UpdateStepStatus(ctx, s.jobID, stepID, s.token, req)
314 if err != nil {
315 return err
316 }
317 if resp.NextToken != "" {
318 s.token = resp.NextToken
319 }
320 return nil
321 }
322
323 func (s *jobSession) AppendLog(ctx context.Context, chunk engine.LogChunk) error {
324 if len(chunk.Chunk) == 0 {
325 return nil
326 }
327 s.mu.Lock()
328 defer s.mu.Unlock()
329 resp, err := s.api.AppendLog(ctx, s.jobID, s.token, api.LogRequest{
330 Seq: chunk.Seq,
331 Chunk: chunk.Chunk,
332 StepID: chunk.StepID,
333 })
334 if err != nil {
335 return err
336 }
337 if resp.NextToken != "" {
338 s.token = resp.NextToken
339 }
340 return nil
341 }
342
343 func (s *jobSession) CancelCheck(ctx context.Context) (api.CancelCheckResponse, error) {
344 s.mu.Lock()
345 defer s.mu.Unlock()
346 resp, err := s.api.CancelCheck(ctx, s.jobID, s.token)
347 if err != nil {
348 return resp, err
349 }
350 if resp.NextToken != "" {
351 s.token = resp.NextToken
352 }
353 return resp, nil
354 }
355
356 func drainLogs(ctx context.Context, session *jobSession, logs <-chan engine.LogChunk) error {
357 for {
358 select {
359 case <-ctx.Done():
360 return ctx.Err()
361 case chunk, ok := <-logs:
362 if !ok {
363 return nil
364 }
365 if err := session.AppendLog(ctx, chunk); err != nil {
366 return err
367 }
368 }
369 }
370 }
371
372 func drainEvents(ctx context.Context, session *jobSession, events <-chan engine.Event) error {
373 for {
374 select {
375 case <-ctx.Done():
376 return ctx.Err()
377 case event, ok := <-events:
378 if !ok {
379 return nil
380 }
381 if event.Log != nil {
382 if err := session.AppendLog(ctx, *event.Log); err != nil {
383 return err
384 }
385 }
386 if event.Step != nil && event.Step.StepID != 0 {
387 if err := session.UpdateStepStatus(ctx, event.Step.StepID, api.StatusRequest{
388 Status: event.Step.Status,
389 Conclusion: event.Step.Conclusion,
390 StartedAt: event.Step.StartedAt,
391 CompletedAt: event.Step.CompletedAt,
392 }); err != nil {
393 return err
394 }
395 }
396 }
397 }
398 }
399
400 func toEngineJob(job api.Job, workspaceDir, defaultImage string) engine.Job {
401 steps := make([]engine.Step, 0, len(job.Steps))
402 for _, step := range job.Steps {
403 steps = append(steps, engine.Step{
404 ID: step.ID,
405 Index: step.Index,
406 StepID: step.StepID,
407 Name: step.Name,
408 If: step.If,
409 Run: step.Run,
410 Uses: step.Uses,
411 WorkingDirectory: step.WorkingDirectory,
412 Env: step.Env,
413 With: step.With,
414 ContinueOnError: step.ContinueOnError,
415 })
416 }
417 return engine.Job{
418 ID: job.ID,
419 RunID: job.RunID,
420 RepoID: job.RepoID,
421 RunIndex: job.RunIndex,
422 WorkflowFile: job.WorkflowFile,
423 WorkflowName: job.WorkflowName,
424 HeadSHA: job.HeadSHA,
425 HeadRef: job.HeadRef,
426 Event: job.Event,
427 EventPayload: job.EventPayload,
428 JobKey: job.JobKey,
429 JobName: job.JobName,
430 RunsOn: job.RunsOn,
431 Needs: append([]string{}, job.Needs...),
432 If: job.If,
433 TimeoutMinutes: job.TimeoutMinutes,
434 Permissions: job.Permissions,
435 Secrets: cloneStringMap(job.Secrets),
436 Env: job.Env,
437 Steps: steps,
438 WorkspaceDir: workspaceDir,
439 Image: defaultImage,
440 MaskValues: append([]string{}, job.MaskValues...),
441 }
442 }
443
// cloneStringMap returns an independent shallow copy of in, or nil when in
// is nil or empty — empty inputs never allocate and read back as absent.
// Uses maps.Clone (Go 1.21+) instead of a hand-rolled copy loop.
func cloneStringMap(in map[string]string) map[string]string {
	if len(in) == 0 {
		return nil
	}
	return maps.Clone(in)
}
454
455 func defaultSleep(ctx context.Context, d time.Duration) error {
456 timer := time.NewTimer(d)
457 defer timer.Stop()
458 select {
459 case <-ctx.Done():
460 return ctx.Err()
461 case <-timer.C:
462 return nil
463 }
464 }
465