Go · 9919 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package runner orchestrates the shithubd-runner claim/execute/status loop.
4 package runner
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "sync"
13 "time"
14
15 "github.com/tenseleyFlow/shithub/internal/runner/api"
16 "github.com/tenseleyFlow/shithub/internal/runner/engine"
17 )
18
// API is the set of server calls the runner depends on.
//
// The three job-scoped calls take the job's current token, and each response
// carries a NextToken that jobSession adopts for the following call when it is
// non-empty.
type API interface {
	// Heartbeat announces the runner's labels and capacity. RunOnce treats a
	// nil claim returned with a nil error as "no job available".
	Heartbeat(ctx context.Context, req api.HeartbeatRequest) (*api.Claim, error)
	// UpdateStatus reports a job-level status change for jobID.
	UpdateStatus(ctx context.Context, jobID int64, token string, req api.StatusRequest) (api.StatusResponse, error)
	// UpdateStepStatus reports a status change for one step of jobID.
	UpdateStepStatus(ctx context.Context, jobID, stepID int64, token string, req api.StatusRequest) (api.StepStatusResponse, error)
	// AppendLog uploads one log chunk for jobID.
	AppendLog(ctx context.Context, jobID int64, token string, req api.LogRequest) (api.LogResponse, error)
}
25
// Workspaces manages the per-job working directories used by the runner.
type Workspaces interface {
	// Prepare returns the directory for the given run/job pair. RunOnce hands
	// the returned path to the engine as the job's workspace directory.
	Prepare(runID, jobID int64) (string, error)
	// Remove discards the workspace for the given run/job pair. RunOnce calls
	// it when it returns after a successful Prepare and only logs a failure.
	Remove(runID, jobID int64) error
}
30
// SleepFunc pauses the poll loop for d. Run stops and returns the error as
// soon as a SleepFunc reports a non-nil one; the default implementation does
// so when ctx is done before d elapses.
type SleepFunc func(ctx context.Context, d time.Duration) error
32
// Options configures a Runner built by New. Zero values of the optional
// fields are replaced with the defaults noted on each field.
type Options struct {
	// API is the server client used for heartbeats, status updates and logs.
	API API
	// Engine executes claimed jobs and provides their log or event stream.
	Engine engine.Engine
	// Workspaces prepares and removes the per-job workspace directory.
	Workspaces Workspaces
	// Logger receives runner diagnostics; nil discards all output.
	Logger *slog.Logger
	// Labels are sent with every heartbeat. New copies the slice.
	Labels []string
	// Capacity is sent with every heartbeat; values <= 0 become 1.
	Capacity int
	// PollInterval is the pause between heartbeats that claimed nothing;
	// values <= 0 become 5 seconds.
	PollInterval time.Duration
	// DefaultImage is passed to the engine as the job's Image.
	DefaultImage string
	// Clock supplies timestamps; nil uses the current UTC time.
	Clock func() time.Time
	// Sleep implements the pause between polls; nil uses a timer-based sleep
	// that aborts when the context is done.
	Sleep SleepFunc
}
45
// Runner drives the heartbeat/claim/execute/report cycle. Its fields mirror
// Options after New has applied the defaults; construct it with New.
type Runner struct {
	api          API
	engine       engine.Engine
	workspaces   Workspaces
	logger       *slog.Logger
	labels       []string
	capacity     int
	pollInterval time.Duration
	defaultImage string
	clock        func() time.Time
	sleep        SleepFunc
}
58
59 func New(opts Options) *Runner {
60 logger := opts.Logger
61 if logger == nil {
62 logger = slog.New(slog.NewTextHandler(io.Discard, nil))
63 }
64 clock := opts.Clock
65 if clock == nil {
66 clock = func() time.Time { return time.Now().UTC() }
67 }
68 sleep := opts.Sleep
69 if sleep == nil {
70 sleep = defaultSleep
71 }
72 poll := opts.PollInterval
73 if poll <= 0 {
74 poll = 5 * time.Second
75 }
76 capacity := opts.Capacity
77 if capacity <= 0 {
78 capacity = 1
79 }
80 return &Runner{
81 api: opts.API,
82 engine: opts.Engine,
83 workspaces: opts.Workspaces,
84 logger: logger,
85 labels: append([]string{}, opts.Labels...),
86 capacity: capacity,
87 pollInterval: poll,
88 defaultImage: opts.DefaultImage,
89 clock: clock,
90 sleep: sleep,
91 }
92 }
93
94 func (r *Runner) Run(ctx context.Context) error {
95 for {
96 claimed, err := r.RunOnce(ctx)
97 if err != nil {
98 if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
99 return err
100 }
101 r.logger.ErrorContext(ctx, "runner loop iteration failed", "error", err)
102 }
103 if claimed {
104 continue
105 }
106 if err := r.sleep(ctx, r.pollInterval); err != nil {
107 return err
108 }
109 }
110 }
111
112 func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
113 claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{Labels: r.labels, Capacity: r.capacity})
114 if err != nil {
115 return false, err
116 }
117 if claim == nil {
118 return false, nil
119 }
120 session := newJobSession(r.api, claim.Job.ID, claim.Token)
121 started := r.clock()
122 workspaceDir, err := r.workspaces.Prepare(claim.Job.RunID, claim.Job.ID)
123 if err != nil {
124 statusErr := r.complete(ctx, session, engine.ConclusionFailure, started, r.clock())
125 return true, errors.Join(fmt.Errorf("prepare workspace: %w", err), statusErr)
126 }
127 defer func() {
128 if err := r.workspaces.Remove(claim.Job.RunID, claim.Job.ID); err != nil {
129 r.logger.WarnContext(ctx, "workspace cleanup failed", "run_id", claim.Job.RunID, "job_id", claim.Job.ID, "error", err)
130 }
131 }()
132
133 running, err := session.UpdateStatus(ctx, api.StatusRequest{
134 Status: "running",
135 StartedAt: started,
136 })
137 if err != nil {
138 return true, fmt.Errorf("mark job running: %w", err)
139 }
140 if running.NextToken == "" {
141 return true, errors.New("mark job running: server did not return next_token")
142 }
143 var (
144 streamedEvents bool
145 drainErr chan error
146 )
147 if streamer, ok := r.engine.(engine.EventStreamer); ok {
148 events, err := streamer.StreamEvents(ctx, claim.Job.ID)
149 if err != nil {
150 return true, fmt.Errorf("open event stream: %w", err)
151 }
152 streamedEvents = true
153 drainErr = make(chan error, 1)
154 go func() {
155 drainErr <- drainEvents(ctx, session, events)
156 }()
157 } else {
158 logs, err := r.engine.StreamLogs(ctx, claim.Job.ID)
159 if err != nil {
160 return true, fmt.Errorf("open log stream: %w", err)
161 }
162 drainErr = make(chan error, 1)
163 go func() {
164 drainErr <- drainLogs(ctx, session, logs)
165 }()
166 }
167
168 outcome, execErr := r.engine.Execute(ctx, toEngineJob(claim.Job, workspaceDir, r.defaultImage))
169 if err := <-drainErr; err != nil {
170 return true, fmt.Errorf("stream runner events: %w", err)
171 }
172 if !streamedEvents {
173 for _, step := range outcome.StepOutcomes {
174 if step.StepID == 0 {
175 continue
176 }
177 if err := session.UpdateStepStatus(ctx, step.StepID, api.StatusRequest{
178 Status: step.Status,
179 Conclusion: step.Conclusion,
180 StartedAt: step.StartedAt,
181 CompletedAt: step.CompletedAt,
182 }); err != nil {
183 return true, fmt.Errorf("mark step completed: %w", err)
184 }
185 }
186 }
187 conclusion := outcome.Conclusion
188 if conclusion == "" {
189 conclusion = engine.ConclusionFailure
190 }
191 completed := outcome.CompletedAt
192 if completed.IsZero() {
193 completed = r.clock()
194 }
195 if outcome.StartedAt.IsZero() {
196 outcome.StartedAt = started
197 }
198 if err := r.complete(ctx, session, conclusion, outcome.StartedAt, completed); err != nil {
199 return true, err
200 }
201 if execErr != nil {
202 r.logger.WarnContext(ctx, "job completed with failing engine outcome", "job_id", claim.Job.ID, "conclusion", conclusion, "error", execErr)
203 }
204 return true, nil
205 }
206
207 func (r *Runner) complete(ctx context.Context, session *jobSession, conclusion string, started, completed time.Time) error {
208 _, err := session.UpdateStatus(ctx, api.StatusRequest{
209 Status: "completed",
210 Conclusion: conclusion,
211 StartedAt: started,
212 CompletedAt: completed,
213 })
214 if err != nil {
215 return fmt.Errorf("mark job completed: %w", err)
216 }
217 return nil
218 }
219
// jobSession binds an API client to one job and tracks that job's token.
type jobSession struct {
	api   API
	jobID int64
	// token is sent with the next request and is replaced whenever a response
	// carries a non-empty NextToken.
	token string
	// mu is held for the whole of every API call made through the session, so
	// calls are serialized and token is read and updated under the lock.
	mu sync.Mutex
}
226
227 func newJobSession(api API, jobID int64, token string) *jobSession {
228 return &jobSession{api: api, jobID: jobID, token: token}
229 }
230
231 func (s *jobSession) UpdateStatus(ctx context.Context, req api.StatusRequest) (api.StatusResponse, error) {
232 s.mu.Lock()
233 defer s.mu.Unlock()
234 resp, err := s.api.UpdateStatus(ctx, s.jobID, s.token, req)
235 if err != nil {
236 return resp, err
237 }
238 if resp.NextToken != "" {
239 s.token = resp.NextToken
240 }
241 return resp, nil
242 }
243
244 func (s *jobSession) UpdateStepStatus(ctx context.Context, stepID int64, req api.StatusRequest) error {
245 s.mu.Lock()
246 defer s.mu.Unlock()
247 resp, err := s.api.UpdateStepStatus(ctx, s.jobID, stepID, s.token, req)
248 if err != nil {
249 return err
250 }
251 if resp.NextToken != "" {
252 s.token = resp.NextToken
253 }
254 return nil
255 }
256
257 func (s *jobSession) AppendLog(ctx context.Context, chunk engine.LogChunk) error {
258 if len(chunk.Chunk) == 0 {
259 return nil
260 }
261 s.mu.Lock()
262 defer s.mu.Unlock()
263 resp, err := s.api.AppendLog(ctx, s.jobID, s.token, api.LogRequest{
264 Seq: chunk.Seq,
265 Chunk: chunk.Chunk,
266 StepID: chunk.StepID,
267 })
268 if err != nil {
269 return err
270 }
271 if resp.NextToken != "" {
272 s.token = resp.NextToken
273 }
274 return nil
275 }
276
277 func drainLogs(ctx context.Context, session *jobSession, logs <-chan engine.LogChunk) error {
278 for {
279 select {
280 case <-ctx.Done():
281 return ctx.Err()
282 case chunk, ok := <-logs:
283 if !ok {
284 return nil
285 }
286 if err := session.AppendLog(ctx, chunk); err != nil {
287 return err
288 }
289 }
290 }
291 }
292
293 func drainEvents(ctx context.Context, session *jobSession, events <-chan engine.Event) error {
294 for {
295 select {
296 case <-ctx.Done():
297 return ctx.Err()
298 case event, ok := <-events:
299 if !ok {
300 return nil
301 }
302 if event.Log != nil {
303 if err := session.AppendLog(ctx, *event.Log); err != nil {
304 return err
305 }
306 }
307 if event.Step != nil && event.Step.StepID != 0 {
308 if err := session.UpdateStepStatus(ctx, event.Step.StepID, api.StatusRequest{
309 Status: event.Step.Status,
310 Conclusion: event.Step.Conclusion,
311 StartedAt: event.Step.StartedAt,
312 CompletedAt: event.Step.CompletedAt,
313 }); err != nil {
314 return err
315 }
316 }
317 }
318 }
319 }
320
321 func toEngineJob(job api.Job, workspaceDir, defaultImage string) engine.Job {
322 steps := make([]engine.Step, 0, len(job.Steps))
323 for _, step := range job.Steps {
324 steps = append(steps, engine.Step{
325 ID: step.ID,
326 Index: step.Index,
327 StepID: step.StepID,
328 Name: step.Name,
329 If: step.If,
330 Run: step.Run,
331 Uses: step.Uses,
332 WorkingDirectory: step.WorkingDirectory,
333 Env: step.Env,
334 With: step.With,
335 ContinueOnError: step.ContinueOnError,
336 })
337 }
338 return engine.Job{
339 ID: job.ID,
340 RunID: job.RunID,
341 RepoID: job.RepoID,
342 RunIndex: job.RunIndex,
343 WorkflowFile: job.WorkflowFile,
344 WorkflowName: job.WorkflowName,
345 HeadSHA: job.HeadSHA,
346 HeadRef: job.HeadRef,
347 Event: job.Event,
348 EventPayload: job.EventPayload,
349 JobKey: job.JobKey,
350 JobName: job.JobName,
351 RunsOn: job.RunsOn,
352 Needs: append([]string{}, job.Needs...),
353 If: job.If,
354 TimeoutMinutes: job.TimeoutMinutes,
355 Permissions: job.Permissions,
356 Secrets: cloneStringMap(job.Secrets),
357 Env: job.Env,
358 Steps: steps,
359 WorkspaceDir: workspaceDir,
360 Image: defaultImage,
361 MaskValues: append([]string{}, job.MaskValues...),
362 }
363 }
364
// cloneStringMap returns an independent copy of in. A nil or empty input
// yields nil rather than an empty map.
func cloneStringMap(in map[string]string) map[string]string {
	size := len(in)
	if size == 0 {
		return nil
	}
	dup := make(map[string]string, size)
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
375
// defaultSleep blocks until d has elapsed or ctx is done, whichever comes
// first. It returns nil when the timer fires and ctx.Err() when the context
// ends the wait. The timer is stopped before returning.
func defaultSleep(ctx context.Context, d time.Duration) error {
	wake := time.NewTimer(d)
	defer wake.Stop()

	select {
	case <-wake.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
386