Go · 4086 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package concurrency owns workflow-level Actions concurrency groups.
4 package concurrency
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "strings"
11 "unicode/utf8"
12
13 "github.com/tenseleyFlow/shithub/internal/actions/runstate"
14 actionsdb "github.com/tenseleyFlow/shithub/internal/actions/sqlc"
15 "github.com/tenseleyFlow/shithub/internal/actions/workflow"
16 )
17
// MaxGroupChars is the longest concurrency group, counted in runes rather
// than bytes, that a workflow run may carry. It mirrors the CHECK constraint
// on workflow_runs.concurrency_group; checking it before insert lets workflow
// authors see a descriptive error instead of a bare constraint violation.
const MaxGroupChars = 256

// CancelReason is the metrics label recorded when a newer run cancels the
// older occupants of its concurrency group.
const CancelReason = "concurrency"
28
29 // ResolveInput carries the trigger-time context available to concurrency.group
30 // expression evaluation. Secrets are intentionally not populated here.
31 type ResolveInput struct {
32 Workflow *workflow.Workflow
33 EventPayload map[string]any
34 HeadSHA string
35 HeadRef string
36 }
37
38 // Resolution is the trigger-time concurrency policy for one workflow run.
39 type Resolution struct {
40 Group string
41 CancelInProgress bool
42 }
43
44 // Resolve evaluates workflow.concurrency.group against the trigger context.
45 func Resolve(in ResolveInput) (Resolution, error) {
46 if in.Workflow == nil {
47 return Resolution{}, errors.New("actions concurrency: nil Workflow")
48 }
49 c := in.Workflow.Concurrency
50 raw := strings.TrimSpace(c.Group.Raw)
51 if raw == "" {
52 return Resolution{}, nil
53 }
54 group, err := ResolveGroup(raw, EvalContext{
55 EventPayload: in.EventPayload,
56 HeadSHA: in.HeadSHA,
57 HeadRef: in.HeadRef,
58 })
59 if err != nil {
60 return Resolution{}, err
61 }
62 if group == "" {
63 return Resolution{}, nil
64 }
65 return Resolution{Group: group, CancelInProgress: c.CancelInProgress}, nil
66 }
67
68 // EnforceParams identifies a newly enqueued run whose concurrency group should
69 // be checked against older active runs in the same repo.
70 type EnforceParams struct {
71 Run actionsdb.WorkflowRun
72 CancelInProgress bool
73 }
74
75 // EnforceResult summarizes what the slot manager observed and changed.
76 type EnforceResult struct {
77 BlockingRuns []actionsdb.WorkflowRun
78 CancelledJobs []actionsdb.WorkflowJob
79 }
80
81 // Enforce applies workflow-level concurrency rules. With cancel-in-progress it
82 // requests cancellation for older active occupants. Without it, this only locks
83 // and reports blockers; ClaimQueuedWorkflowJob keeps the newer run pending.
84 func Enforce(
85 ctx context.Context,
86 q *actionsdb.Queries,
87 db actionsdb.DBTX,
88 p EnforceParams,
89 ) (EnforceResult, error) {
90 if q == nil {
91 return EnforceResult{}, errors.New("actions concurrency: nil Queries")
92 }
93 if strings.TrimSpace(p.Run.ConcurrencyGroup) == "" {
94 return EnforceResult{}, nil
95 }
96 blockers, err := q.ListBlockingConcurrencyRunsForUpdate(ctx, db, actionsdb.ListBlockingConcurrencyRunsForUpdateParams{
97 RepoID: p.Run.RepoID,
98 ConcurrencyGroup: p.Run.ConcurrencyGroup,
99 RunID: p.Run.ID,
100 })
101 if err != nil {
102 return EnforceResult{}, err
103 }
104 out := EnforceResult{BlockingRuns: blockers}
105 if !p.CancelInProgress || len(blockers) == 0 {
106 return out, nil
107 }
108 for _, blocker := range blockers {
109 changed, err := q.RequestWorkflowRunCancel(ctx, db, blocker.ID)
110 if err != nil {
111 return EnforceResult{}, err
112 }
113 for _, job := range changed {
114 if job.Status == actionsdb.WorkflowJobStatusCancelled {
115 if _, err := q.CancelOpenWorkflowStepsForJob(ctx, db, job.ID); err != nil {
116 return EnforceResult{}, err
117 }
118 }
119 }
120 if len(changed) > 0 {
121 if _, _, err := runstate.RollupAfterCancel(ctx, q, db, blocker.ID); err != nil {
122 return EnforceResult{}, err
123 }
124 out.CancelledJobs = append(out.CancelledJobs, changed...)
125 }
126 }
127 return out, nil
128 }
129
130 func validateGroup(group string) (string, error) {
131 group = strings.TrimSpace(group)
132 if group == "" {
133 return "", nil
134 }
135 if utf8.RuneCountInString(group) > MaxGroupChars {
136 return "", fmt.Errorf("actions concurrency: group exceeds %d characters", MaxGroupChars)
137 }
138 return group, nil
139 }
140