// SPDX-License-Identifier: AGPL-3.0-or-later

// Package runner orchestrates the shithubd-runner claim/execute/status loop.
package runner

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tenseleyFlow/shithub/internal/runner/api"
	"github.com/tenseleyFlow/shithub/internal/runner/engine"
)

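// API is the subset of the shithub control-plane client the runner needs:
// claiming work via heartbeats, reporting job and step status, appending
// log chunks, and polling for cancellation.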
type API interface {
	Heartbeat(ctx context.Context, req api.HeartbeatRequest) (*api.Claim, error)
	UpdateStatus(ctx context.Context, jobID int64, token string, req api.StatusRequest) (api.StatusResponse, error)
	UpdateStepStatus(ctx context.Context, jobID, stepID int64, token string, req api.StatusRequest) (api.StepStatusResponse, error)
	AppendLog(ctx context.Context, jobID int64, token string, req api.LogRequest) (api.LogResponse, error)
	CancelCheck(ctx context.Context, jobID int64, token string) (api.CancelCheckResponse, error)
}

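// Workspaces provisions and tears down per-job working directories.
// Prepare returns the directory the engine should run the job in.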
type Workspaces interface {
	Prepare(runID, jobID int64) (string, error)
	Remove(runID, jobID int64) error
}

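// SleepFunc blocks for d or until ctx is done, in which case it returns
// ctx.Err(). It is exposed as an option so callers (tests in particular)
// can substitute a controllable sleep.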
type SleepFunc func(ctx context.Context, d time.Duration) error

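// Options configures a Runner. API, Engine, and Workspaces are required;
// every other field is optional and falls back to the defaults applied by
// New.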
type Options struct {
	API                API
	Engine             engine.Engine
	Workspaces         Workspaces
	Logger             *slog.Logger
	Labels             []string
	Capacity           int
	PollInterval       time.Duration
	CancelPollInterval time.Duration
	DefaultImage       string
	Clock              func() time.Time
	Sleep              SleepFunc
}

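// Runner claims jobs from the control plane, executes them through the
// configured engine, and streams status, step results, and logs back.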
type Runner struct {
	api                API
	engine             engine.Engine
	workspaces         Workspaces
	logger             *slog.Logger
	labels             []string
	capacity           int
	pollInterval       time.Duration
	cancelPollInterval time.Duration
	defaultImage       string
	clock              func() time.Time
	sleep              SleepFunc
}

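// New builds a Runner from opts, filling in defaults for optional fields
// left zero: a discarding logger, a UTC clock, a timer-backed sleep, a 5s
// poll interval, a 2s cancel-poll interval, and a capacity of 1.
//
// A minimal usage sketch; client, eng, and ws stand in for whatever API,
// engine.Engine, and Workspaces implementations the caller wires up:
//
//	r := runner.New(runner.Options{
//		API:        client,
//		Engine:     eng,
//		Workspaces: ws,
//		Labels:     []string{"linux"},
//	})
//	if err := r.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
//		log.Fatal(err)
//	}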
func New(opts Options) *Runner {
	logger := opts.Logger
	if logger == nil {
		logger = slog.New(slog.NewTextHandler(io.Discard, nil))
	}
	clock := opts.Clock
	if clock == nil {
		clock = func() time.Time { return time.Now().UTC() }
	}
	sleep := opts.Sleep
	if sleep == nil {
		sleep = defaultSleep
	}
	poll := opts.PollInterval
	if poll <= 0 {
		poll = 5 * time.Second
	}
	cancelPoll := opts.CancelPollInterval
	if cancelPoll <= 0 {
		cancelPoll = 2 * time.Second
	}
	capacity := opts.Capacity
	if capacity <= 0 {
		capacity = 1
	}
	return &Runner{
		api:                opts.API,
		engine:             opts.Engine,
		workspaces:         opts.Workspaces,
		logger:             logger,
		labels:             append([]string{}, opts.Labels...),
		capacity:           capacity,
		pollInterval:       poll,
		cancelPollInterval: cancelPoll,
		defaultImage:       opts.DefaultImage,
		clock:              clock,
		sleep:              sleep,
	}
}

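// Run loops RunOnce until ctx ends: iterations that claimed a job repeat
// immediately, while idle iterations sleep for the poll interval first.
// Transient errors are logged and retried; only context cancellation (or
// deadline expiry) terminates the loop.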
func (r *Runner) Run(ctx context.Context) error {
	for {
		claimed, err := r.RunOnce(ctx)
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return err
			}
			r.logger.ErrorContext(ctx, "runner loop iteration failed", "error", err)
		}
		if claimed {
			continue
		}
		if err := r.sleep(ctx, r.pollInterval); err != nil {
			return err
		}
	}
}

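// RunOnce performs a single heartbeat and, if the server returned a claim,
// runs that job to completion. The boolean reports whether a job was
// claimed, letting Run skip the idle sleep after productive iterations.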
func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
	claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{Labels: r.labels, Capacity: r.capacity})
	if err != nil {
		return false, err
	}
	if claim == nil {
		return false, nil
	}
	session := newJobSession(r.api, claim.Job.ID, claim.Token)
	started := r.clock()
	workspaceDir, err := r.workspaces.Prepare(claim.Job.RunID, claim.Job.ID)
	if err != nil {
		statusErr := r.complete(ctx, session, engine.ConclusionFailure, started, r.clock())
		return true, errors.Join(fmt.Errorf("prepare workspace: %w", err), statusErr)
	}
	defer func() {
		if err := r.workspaces.Remove(claim.Job.RunID, claim.Job.ID); err != nil {
			r.logger.WarnContext(ctx, "workspace cleanup failed", "run_id", claim.Job.RunID, "job_id", claim.Job.ID, "error", err)
		}
	}()

	running, err := session.UpdateStatus(ctx, api.StatusRequest{
		Status:    "running",
		StartedAt: started,
	})
	if err != nil {
		return true, fmt.Errorf("mark job running: %w", err)
	}
	if running.NextToken == "" {
		return true, errors.New("mark job running: server did not return next_token")
	}
	var (
		streamedEvents bool
		drainErr       chan error
	)
	if streamer, ok := r.engine.(engine.EventStreamer); ok {
		events, err := streamer.StreamEvents(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open event stream: %w", err)
		}
		streamedEvents = true
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainEvents(ctx, session, events)
		}()
	} else {
		logs, err := r.engine.StreamLogs(ctx, claim.Job.ID)
		if err != nil {
			return true, fmt.Errorf("open log stream: %w", err)
		}
		drainErr = make(chan error, 1)
		go func() {
			drainErr <- drainLogs(ctx, session, logs)
		}()
	}

	execCtx, execCancel := context.WithCancel(ctx)
	defer execCancel() // release execCtx on every return path, not only after a cancel request
	watchCtx, stopCancelWatch := context.WithCancel(ctx)
	var cancelRequested atomic.Bool
	cancelWatchErr := make(chan error, 1)
	go func() {
		cancelWatchErr <- r.watchCancel(watchCtx, session, claim.Job.ID, execCancel, &cancelRequested)
	}()

	outcome, execErr := r.engine.Execute(execCtx, toEngineJob(claim.Job, workspaceDir, r.defaultImage))
	stopCancelWatch()
	if err := <-cancelWatchErr; err != nil {
		return true, fmt.Errorf("watch job cancellation: %w", err)
	}
	if err := <-drainErr; err != nil {
		return true, fmt.Errorf("drain engine output: %w", err)
	}
	if !streamedEvents {
		for _, step := range outcome.StepOutcomes {
			if step.StepID == 0 {
				continue
			}
			if err := session.UpdateStepStatus(ctx, step.StepID, api.StatusRequest{
				Status:      step.Status,
				Conclusion:  step.Conclusion,
				StartedAt:   step.StartedAt,
				CompletedAt: step.CompletedAt,
			}); err != nil {
				return true, fmt.Errorf("mark step completed: %w", err)
			}
		}
	}
	conclusion := outcome.Conclusion
	if conclusion == "" {
		conclusion = engine.ConclusionFailure
	}
	finalStatus := "completed"
	if cancelRequested.Load() {
		finalStatus = "cancelled"
		conclusion = engine.ConclusionCancelled
	}
	completed := outcome.CompletedAt
	if completed.IsZero() {
		completed = r.clock()
	}
	if outcome.StartedAt.IsZero() {
		outcome.StartedAt = started
	}
	if err := r.finish(ctx, session, finalStatus, conclusion, outcome.StartedAt, completed); err != nil {
		return true, err
	}
	if execErr != nil && !cancelRequested.Load() {
		r.logger.WarnContext(ctx, "job completed with failing engine outcome", "job_id", claim.Job.ID, "conclusion", conclusion, "error", execErr)
	}
	return true, nil
}

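// complete marks the job "completed" with the given conclusion.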
func (r *Runner) complete(ctx context.Context, session *jobSession, conclusion string, started, completed time.Time) error {
	return r.finish(ctx, session, "completed", conclusion, started, completed)
}

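// finish posts the terminal status update for the job.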
func (r *Runner) finish(ctx context.Context, session *jobSession, status, conclusion string, started, completed time.Time) error {
	_, err := session.UpdateStatus(ctx, api.StatusRequest{
		Status:      status,
		Conclusion:  conclusion,
		StartedAt:   started,
		CompletedAt: completed,
	})
	if err != nil {
		return fmt.Errorf("mark job %s: %w", status, err)
	}
	return nil
}

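// watchCancel polls the server every cancelPollInterval for a cancellation
// request. When one arrives it records the request, gives the engine up to
// five seconds to stop the job, then cancels the execution context. A nil
// return means ctx ended or cancellation was handled.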
func (r *Runner) watchCancel(
	ctx context.Context,
	session *jobSession,
	jobID int64,
	execCancel context.CancelFunc,
	cancelRequested *atomic.Bool,
) error {
	for {
		if err := r.sleep(ctx, r.cancelPollInterval); err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return nil
			}
			return err
		}
		resp, err := session.CancelCheck(ctx)
		if err != nil {
			if ctx.Err() != nil {
				return nil
			}
			r.logger.WarnContext(ctx, "runner cancel check failed", "job_id", jobID, "error", err)
			continue
		}
		if !resp.Cancelled {
			continue
		}
		cancelRequested.Store(true)
		killCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		if err := r.engine.Cancel(killCtx, jobID); err != nil && !errors.Is(err, engine.ErrUnsupported) {
			r.logger.WarnContext(ctx, "runner engine cancel failed", "job_id", jobID, "error", err)
		}
		cancel()
		execCancel()
		return nil
	}
}

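// jobSession serializes all API calls for a single claimed job and rotates
// the job token whenever the server returns a next_token, so each request
// carries the freshest credential.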
type jobSession struct {
	api   API
	jobID int64
	token string
	mu    sync.Mutex
}

func newJobSession(api API, jobID int64, token string) *jobSession {
	return &jobSession{api: api, jobID: jobID, token: token}
}

func (s *jobSession) UpdateStatus(ctx context.Context, req api.StatusRequest) (api.StatusResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	resp, err := s.api.UpdateStatus(ctx, s.jobID, s.token, req)
	if err != nil {
		return resp, err
	}
	if resp.NextToken != "" {
		s.token = resp.NextToken
	}
	return resp, nil
}

func (s *jobSession) UpdateStepStatus(ctx context.Context, stepID int64, req api.StatusRequest) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	resp, err := s.api.UpdateStepStatus(ctx, s.jobID, stepID, s.token, req)
	if err != nil {
		return err
	}
	if resp.NextToken != "" {
		s.token = resp.NextToken
	}
	return nil
}

func (s *jobSession) AppendLog(ctx context.Context, chunk engine.LogChunk) error {
	if len(chunk.Chunk) == 0 {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	resp, err := s.api.AppendLog(ctx, s.jobID, s.token, api.LogRequest{
		Seq:    chunk.Seq,
		Chunk:  chunk.Chunk,
		StepID: chunk.StepID,
	})
	if err != nil {
		return err
	}
	if resp.NextToken != "" {
		s.token = resp.NextToken
	}
	return nil
}

func (s *jobSession) CancelCheck(ctx context.Context) (api.CancelCheckResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	resp, err := s.api.CancelCheck(ctx, s.jobID, s.token)
	if err != nil {
		return resp, err
	}
	if resp.NextToken != "" {
		s.token = resp.NextToken
	}
	return resp, nil
}

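// drainLogs forwards log chunks from the engine to the server until the
// channel closes or ctx ends.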
func drainLogs(ctx context.Context, session *jobSession, logs <-chan engine.LogChunk) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case chunk, ok := <-logs:
			if !ok {
				return nil
			}
			if err := session.AppendLog(ctx, chunk); err != nil {
				return err
			}
		}
	}
}

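// drainEvents forwards interleaved log and step-status events from the
// engine to the server until the channel closes or ctx ends.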
func drainEvents(ctx context.Context, session *jobSession, events <-chan engine.Event) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case event, ok := <-events:
			if !ok {
				return nil
			}
			if event.Log != nil {
				if err := session.AppendLog(ctx, *event.Log); err != nil {
					return err
				}
			}
			if event.Step != nil && event.Step.StepID != 0 {
				if err := session.UpdateStepStatus(ctx, event.Step.StepID, api.StatusRequest{
					Status:      event.Step.Status,
					Conclusion:  event.Step.Conclusion,
					StartedAt:   event.Step.StartedAt,
					CompletedAt: event.Step.CompletedAt,
				}); err != nil {
					return err
				}
			}
		}
	}
}

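// toEngineJob translates the claimed API job into the engine's job type,
// pinning the prepared workspace directory and the runner's default image.
// Needs, MaskValues, and Secrets are defensively copied.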
func toEngineJob(job api.Job, workspaceDir, defaultImage string) engine.Job {
	steps := make([]engine.Step, 0, len(job.Steps))
	for _, step := range job.Steps {
		steps = append(steps, engine.Step{
			ID:               step.ID,
			Index:            step.Index,
			StepID:           step.StepID,
			Name:             step.Name,
			If:               step.If,
			Run:              step.Run,
			Uses:             step.Uses,
			WorkingDirectory: step.WorkingDirectory,
			Env:              step.Env,
			With:             step.With,
			ContinueOnError:  step.ContinueOnError,
		})
	}
	return engine.Job{
		ID:             job.ID,
		RunID:          job.RunID,
		RepoID:         job.RepoID,
		RunIndex:       job.RunIndex,
		WorkflowFile:   job.WorkflowFile,
		WorkflowName:   job.WorkflowName,
		HeadSHA:        job.HeadSHA,
		HeadRef:        job.HeadRef,
		Event:          job.Event,
		EventPayload:   job.EventPayload,
		JobKey:         job.JobKey,
		JobName:        job.JobName,
		RunsOn:         job.RunsOn,
		Needs:          append([]string{}, job.Needs...),
		If:             job.If,
		TimeoutMinutes: job.TimeoutMinutes,
		Permissions:    job.Permissions,
		Secrets:        cloneStringMap(job.Secrets),
		Env:            job.Env,
		Steps:          steps,
		WorkspaceDir:   workspaceDir,
		Image:          defaultImage,
		MaskValues:     append([]string{}, job.MaskValues...),
	}
}

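// cloneStringMap returns a shallow copy of in, or nil when in is empty.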
func cloneStringMap(in map[string]string) map[string]string {
	if len(in) == 0 {
		return nil
	}
	out := make(map[string]string, len(in))
	for k, v := range in {
		out[k] = v
	}
	return out
}

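// defaultSleep waits for d on a timer, returning early with ctx.Err() if
// the context ends first.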
func defaultSleep(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
| 465 |