tenseleyflow/shithub / 6931e9a

Browse files

actions/runners: enforce drain and hard revoke (S41j-4)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
6931e9a8d205f67e026b68f321f91b936cb25937
Parents
feb5cd7
Tree
e193b26

7 changed files

Status File + -
M cmd/shithubd-runner/run.go 7 0
M internal/runner/api/client.go 2 0
M internal/runner/api/client_test.go 8 2
M internal/runner/runner.go 12 1
M internal/runner/runner_test.go 21 2
M internal/web/handlers/api/runners.go 89 19
M internal/web/handlers/api/runners_test.go 103 1
cmd/shithubd-runner/run.go modified
@@ -15,6 +15,7 @@ import (
15
 	runnerconfig "github.com/tenseleyFlow/shithub/internal/runner/config"
15
 	runnerconfig "github.com/tenseleyFlow/shithub/internal/runner/config"
16
 	"github.com/tenseleyFlow/shithub/internal/runner/engine"
16
 	"github.com/tenseleyFlow/shithub/internal/runner/engine"
17
 	"github.com/tenseleyFlow/shithub/internal/runner/workspace"
17
 	"github.com/tenseleyFlow/shithub/internal/runner/workspace"
18
+	"github.com/tenseleyFlow/shithub/internal/version"
18
 )
19
 )
19
 
20
 
20
 var runConfigPath string
21
 var runConfigPath string
@@ -52,6 +53,10 @@ var runCmd = &cobra.Command{
52
 		if err != nil {
53
 		if err != nil {
53
 			return err
54
 			return err
54
 		}
55
 		}
56
+		hostName, err := os.Hostname()
57
+		if err != nil {
58
+			hostName = ""
59
+		}
55
 		execEngine := engine.NewDocker(engine.DockerConfig{
60
 		execEngine := engine.NewDocker(engine.DockerConfig{
56
 			Binary:         cfg.Engine.Kind,
61
 			Binary:         cfg.Engine.Kind,
57
 			DefaultImage:   cfg.Engine.DefaultImage,
62
 			DefaultImage:   cfg.Engine.DefaultImage,
@@ -73,6 +78,8 @@ var runCmd = &cobra.Command{
73
 			Logger:       logger,
78
 			Logger:       logger,
74
 			Labels:       cfg.Runner.Labels,
79
 			Labels:       cfg.Runner.Labels,
75
 			Capacity:     cfg.Runner.Capacity,
80
 			Capacity:     cfg.Runner.Capacity,
81
+			HostName:     hostName,
82
+			Version:      version.String(),
76
 			PollInterval: cfg.Runner.PollInterval,
83
 			PollInterval: cfg.Runner.PollInterval,
77
 			DefaultImage: cfg.Engine.DefaultImage,
84
 			DefaultImage: cfg.Engine.DefaultImage,
78
 			Clock:        func() time.Time { return time.Now().UTC() },
85
 			Clock:        func() time.Time { return time.Now().UTC() },
internal/runner/api/client.go modified
@@ -47,6 +47,8 @@ func New(cfg Config) (*Client, error) {
47
 type HeartbeatRequest struct {
47
 type HeartbeatRequest struct {
48
 	Labels   []string `json:"labels"`
48
 	Labels   []string `json:"labels"`
49
 	Capacity int      `json:"capacity"`
49
 	Capacity int      `json:"capacity"`
50
+	HostName string   `json:"host_name,omitempty"`
51
+	Version  string   `json:"version,omitempty"`
50
 }
52
 }
51
 
53
 
52
 type Claim struct {
54
 type Claim struct {
internal/runner/api/client_test.go modified
@@ -24,7 +24,8 @@ func TestHeartbeat_ClaimsJob(t *testing.T) {
24
 		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
24
 		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
25
 			t.Fatalf("Decode: %v", err)
25
 			t.Fatalf("Decode: %v", err)
26
 		}
26
 		}
27
-		if req.Capacity != 2 || strings.Join(req.Labels, ",") != "self-hosted,linux" {
27
+		if req.Capacity != 2 || strings.Join(req.Labels, ",") != "self-hosted,linux" ||
28
+			req.HostName != "runner-host" || req.Version != "dev-test" {
28
 			t.Fatalf("request: %#v", req)
29
 			t.Fatalf("request: %#v", req)
29
 		}
30
 		}
30
 		w.Header().Set("Content-Type", "application/json")
31
 		w.Header().Set("Content-Type", "application/json")
@@ -40,7 +41,12 @@ func TestHeartbeat_ClaimsJob(t *testing.T) {
40
 	if err != nil {
41
 	if err != nil {
41
 		t.Fatalf("New: %v", err)
42
 		t.Fatalf("New: %v", err)
42
 	}
43
 	}
43
-	claim, err := client.Heartbeat(t.Context(), HeartbeatRequest{Labels: []string{"self-hosted", "linux"}, Capacity: 2})
44
+	claim, err := client.Heartbeat(t.Context(), HeartbeatRequest{
45
+		Labels:   []string{"self-hosted", "linux"},
46
+		Capacity: 2,
47
+		HostName: "runner-host",
48
+		Version:  "dev-test",
49
+	})
44
 	if err != nil {
50
 	if err != nil {
45
 		t.Fatalf("Heartbeat: %v", err)
51
 		t.Fatalf("Heartbeat: %v", err)
46
 	}
52
 	}
internal/runner/runner.go modified
@@ -39,6 +39,8 @@ type Options struct {
39
 	Logger             *slog.Logger
39
 	Logger             *slog.Logger
40
 	Labels             []string
40
 	Labels             []string
41
 	Capacity           int
41
 	Capacity           int
42
+	HostName           string
43
+	Version            string
42
 	PollInterval       time.Duration
44
 	PollInterval       time.Duration
43
 	CancelPollInterval time.Duration
45
 	CancelPollInterval time.Duration
44
 	DefaultImage       string
46
 	DefaultImage       string
@@ -53,6 +55,8 @@ type Runner struct {
53
 	logger             *slog.Logger
55
 	logger             *slog.Logger
54
 	labels             []string
56
 	labels             []string
55
 	capacity           int
57
 	capacity           int
58
+	hostName           string
59
+	version            string
56
 	pollInterval       time.Duration
60
 	pollInterval       time.Duration
57
 	cancelPollInterval time.Duration
61
 	cancelPollInterval time.Duration
58
 	defaultImage       string
62
 	defaultImage       string
@@ -92,6 +96,8 @@ func New(opts Options) *Runner {
92
 		logger:             logger,
96
 		logger:             logger,
93
 		labels:             append([]string{}, opts.Labels...),
97
 		labels:             append([]string{}, opts.Labels...),
94
 		capacity:           capacity,
98
 		capacity:           capacity,
99
+		hostName:           opts.HostName,
100
+		version:            opts.Version,
95
 		pollInterval:       poll,
101
 		pollInterval:       poll,
96
 		cancelPollInterval: cancelPoll,
102
 		cancelPollInterval: cancelPoll,
97
 		defaultImage:       opts.DefaultImage,
103
 		defaultImage:       opts.DefaultImage,
@@ -119,7 +125,12 @@ func (r *Runner) Run(ctx context.Context) error {
119
 }
125
 }
120
 
126
 
121
 func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
127
 func (r *Runner) RunOnce(ctx context.Context) (bool, error) {
122
-	claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{Labels: r.labels, Capacity: r.capacity})
128
+	claim, err := r.api.Heartbeat(ctx, api.HeartbeatRequest{
129
+		Labels:   r.labels,
130
+		Capacity: r.capacity,
131
+		HostName: r.hostName,
132
+		Version:  r.version,
133
+	})
123
 	if err != nil {
134
 	if err != nil {
124
 		return false, err
135
 		return false, err
125
 	}
136
 	}
internal/runner/runner_test.go modified
@@ -15,6 +15,7 @@ import (
15
 
15
 
16
 type fakeAPI struct {
16
 type fakeAPI struct {
17
 	claim        *api.Claim
17
 	claim        *api.Claim
18
+	heartbeats   []api.HeartbeatRequest
18
 	statuses     []api.StatusRequest
19
 	statuses     []api.StatusRequest
19
 	stepStatuses []api.StatusRequest
20
 	stepStatuses []api.StatusRequest
20
 	logs         []api.LogRequest
21
 	logs         []api.LogRequest
@@ -24,7 +25,8 @@ type fakeAPI struct {
24
 	next         int
25
 	next         int
25
 }
26
 }
26
 
27
 
27
-func (f *fakeAPI) Heartbeat(_ context.Context, _ api.HeartbeatRequest) (*api.Claim, error) {
28
+func (f *fakeAPI) Heartbeat(_ context.Context, req api.HeartbeatRequest) (*api.Claim, error) {
29
+	f.heartbeats = append(f.heartbeats, req)
28
 	return f.claim, nil
30
 	return f.claim, nil
29
 }
31
 }
30
 
32
 
@@ -160,7 +162,16 @@ func (f *fakeWorkspaces) Remove(_, _ int64) error {
160
 
162
 
161
 func TestRunOnce_NoClaim(t *testing.T) {
163
 func TestRunOnce_NoClaim(t *testing.T) {
162
 	t.Parallel()
164
 	t.Parallel()
163
-	r := New(Options{API: &fakeAPI{}, Engine: &fakeEngine{}, Workspaces: &fakeWorkspaces{}})
165
+	fapi := &fakeAPI{}
166
+	r := New(Options{
167
+		API:        fapi,
168
+		Engine:     &fakeEngine{},
169
+		Workspaces: &fakeWorkspaces{},
170
+		Labels:     []string{"self-hosted", "linux"},
171
+		Capacity:   2,
172
+		HostName:   "runner-host",
173
+		Version:    "dev-test",
174
+	})
164
 	claimed, err := r.RunOnce(t.Context())
175
 	claimed, err := r.RunOnce(t.Context())
165
 	if err != nil {
176
 	if err != nil {
166
 		t.Fatalf("RunOnce: %v", err)
177
 		t.Fatalf("RunOnce: %v", err)
@@ -168,6 +179,14 @@ func TestRunOnce_NoClaim(t *testing.T) {
168
 	if claimed {
179
 	if claimed {
169
 		t.Fatal("claimed = true")
180
 		t.Fatal("claimed = true")
170
 	}
181
 	}
182
+	if len(fapi.heartbeats) != 1 {
183
+		t.Fatalf("heartbeats: %#v", fapi.heartbeats)
184
+	}
185
+	got := fapi.heartbeats[0]
186
+	if got.Capacity != 2 || got.HostName != "runner-host" || got.Version != "dev-test" ||
187
+		len(got.Labels) != 2 || got.Labels[0] != "self-hosted" || got.Labels[1] != "linux" {
188
+		t.Fatalf("heartbeat: %#v", got)
189
+	}
171
 }
190
 }
172
 
191
 
173
 func TestRunOnce_ExecutesAndCompletesSuccess(t *testing.T) {
192
 func TestRunOnce_ExecutesAndCompletesSuccess(t *testing.T) {
internal/web/handlers/api/runners.go modified
@@ -16,6 +16,7 @@ import (
16
 	"strconv"
16
 	"strconv"
17
 	"strings"
17
 	"strings"
18
 	"time"
18
 	"time"
19
+	"unicode/utf8"
19
 
20
 
20
 	"github.com/go-chi/chi/v5"
21
 	"github.com/go-chi/chi/v5"
21
 	"github.com/jackc/pgx/v5"
22
 	"github.com/jackc/pgx/v5"
@@ -55,8 +56,14 @@ func (h *Handlers) mountRunners(r chi.Router) {
55
 type runnerHeartbeatRequest struct {
56
 type runnerHeartbeatRequest struct {
56
 	Labels   []string `json:"labels"`
57
 	Labels   []string `json:"labels"`
57
 	Capacity int      `json:"capacity"`
58
 	Capacity int      `json:"capacity"`
59
+	HostName string   `json:"host_name"`
60
+	Version  string   `json:"version"`
58
 }
61
 }
59
 
62
 
63
+const runnerMetadataMaxBytes = 255
64
+
65
+var errRunnerRevoked = errors.New("runner is revoked")
66
+
60
 func (h *Handlers) runnerHeartbeat(w http.ResponseWriter, r *http.Request) {
67
 func (h *Handlers) runnerHeartbeat(w http.ResponseWriter, r *http.Request) {
61
 	if h.d.RunnerJWT == nil {
68
 	if h.d.RunnerJWT == nil {
62
 		writeAPIError(w, http.StatusServiceUnavailable, "runner API is not configured")
69
 		writeAPIError(w, http.StatusServiceUnavailable, "runner API is not configured")
@@ -94,9 +101,22 @@ func (h *Handlers) runnerHeartbeat(w http.ResponseWriter, r *http.Request) {
94
 		writeAPIError(w, http.StatusBadRequest, "capacity must be between 1 and 64")
101
 		writeAPIError(w, http.StatusBadRequest, "capacity must be between 1 and 64")
95
 		return
102
 		return
96
 	}
103
 	}
104
+	hostName := cleanRunnerMetadata(body.HostName)
105
+	if hostName == "" {
106
+		hostName = runner.HostName
107
+	}
108
+	version := cleanRunnerMetadata(body.Version)
109
+	if version == "" {
110
+		version = runner.Version
111
+	}
97
 
112
 
98
-	job, steps, resolvedSecrets, claimed, err := h.claimRunnerJob(r.Context(), runner.ID, labels, int32(capacity))
113
+	job, steps, resolvedSecrets, claimed, err := h.claimRunnerJob(r.Context(), runner.ID, labels, int32(capacity), hostName, version)
99
 	if err != nil {
114
 	if err != nil {
115
+		if errors.Is(err, errRunnerRevoked) {
116
+			metrics.ActionsRunnerHeartbeatsTotal.WithLabelValues("rejected").Inc()
117
+			writeAPIError(w, http.StatusUnauthorized, "runner revoked")
118
+			return
119
+		}
100
 		h.d.Logger.ErrorContext(r.Context(), "runner heartbeat claim failed", "runner_id", runner.ID, "error", err)
120
 		h.d.Logger.ErrorContext(r.Context(), "runner heartbeat claim failed", "runner_id", runner.ID, "error", err)
101
 		writeAPIError(w, http.StatusInternalServerError, "runner heartbeat failed")
121
 		writeAPIError(w, http.StatusInternalServerError, "runner heartbeat failed")
102
 		return
122
 		return
@@ -136,6 +156,22 @@ func (h *Handlers) runnerHeartbeat(w http.ResponseWriter, r *http.Request) {
136
 	writeJSON(w, http.StatusOK, h.presentRunnerClaim(job, steps, resolvedSecrets, token, checkoutToken, time.Unix(claims.Exp, 0)))
156
 	writeJSON(w, http.StatusOK, h.presentRunnerClaim(job, steps, resolvedSecrets, token, checkoutToken, time.Unix(claims.Exp, 0)))
137
 }
157
 }
138
 
158
 
159
+func cleanRunnerMetadata(value string) string {
160
+	value = strings.TrimSpace(value)
161
+	if len(value) <= runnerMetadataMaxBytes {
162
+		return value
163
+	}
164
+	var b strings.Builder
165
+	for _, r := range value {
166
+		runeLen := utf8.RuneLen(r)
167
+		if runeLen < 0 || b.Len()+runeLen > runnerMetadataMaxBytes {
168
+			break
169
+		}
170
+		b.WriteRune(r)
171
+	}
172
+	return strings.TrimSpace(b.String())
173
+}
174
+
139
 func (h *Handlers) authenticateRunner(w http.ResponseWriter, r *http.Request) (actionsdb.GetRunnerByTokenHashRow, bool) {
175
 func (h *Handlers) authenticateRunner(w http.ResponseWriter, r *http.Request) (actionsdb.GetRunnerByTokenHashRow, bool) {
140
 	const prefix = "Bearer "
176
 	const prefix = "Bearer "
141
 	authz := r.Header.Get("Authorization")
177
 	authz := r.Header.Get("Authorization")
@@ -178,6 +214,8 @@ func (h *Handlers) claimRunnerJob(
178
 	runnerID int64,
214
 	runnerID int64,
179
 	labels []string,
215
 	labels []string,
180
 	capacity int32,
216
 	capacity int32,
217
+	hostName string,
218
+	version string,
181
 ) (actionsdb.ClaimQueuedWorkflowJobRow, []actionsdb.ListRunnerStepsForJobRow, map[string]string, bool, error) {
219
 ) (actionsdb.ClaimQueuedWorkflowJobRow, []actionsdb.ListRunnerStepsForJobRow, map[string]string, bool, error) {
182
 	q := actionsdb.New()
220
 	q := actionsdb.New()
183
 	tx, err := h.d.Pool.Begin(ctx)
221
 	tx, err := h.d.Pool.Begin(ctx)
@@ -191,20 +229,44 @@ func (h *Handlers) claimRunnerJob(
191
 		}
229
 		}
192
 	}()
230
 	}()
193
 
231
 
194
-	if _, err := q.LockRunnerByID(ctx, tx, runnerID); err != nil {
232
+	lockedRunner, err := q.LockRunnerByID(ctx, tx, runnerID)
233
+	if err != nil {
195
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
234
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
196
 	}
235
 	}
236
+	if lockedRunner.RevokedAt.Valid {
237
+		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, errRunnerRevoked
238
+	}
197
 	running, err := q.CountRunningJobsForRunner(ctx, tx, runnerID)
239
 	running, err := q.CountRunningJobsForRunner(ctx, tx, runnerID)
198
 	if err != nil {
240
 	if err != nil {
199
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
241
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
200
 	}
242
 	}
201
-	if running >= capacity {
243
+	heartbeat := func(status actionsdb.WorkflowRunnerStatus) error {
202
-		if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
244
+		_, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
203
 			ID:       runnerID,
245
 			ID:       runnerID,
204
 			Labels:   labels,
246
 			Labels:   labels,
205
 			Capacity: capacity,
247
 			Capacity: capacity,
206
-			Status:   actionsdb.WorkflowRunnerStatusBusy,
248
+			Status:   status,
207
-		}); err != nil {
249
+			HostName: hostName,
250
+			Version:  version,
251
+		})
252
+		return err
253
+	}
254
+	if lockedRunner.DrainingAt.Valid {
255
+		status := actionsdb.WorkflowRunnerStatusIdle
256
+		if running > 0 {
257
+			status = actionsdb.WorkflowRunnerStatusBusy
258
+		}
259
+		if err := heartbeat(status); err != nil {
260
+			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
261
+		}
262
+		if err := tx.Commit(ctx); err != nil {
263
+			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
264
+		}
265
+		committed = true
266
+		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, nil
267
+	}
268
+	if running >= capacity {
269
+		if err := heartbeat(actionsdb.WorkflowRunnerStatusBusy); err != nil {
208
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
270
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
209
 		}
271
 		}
210
 		if err := tx.Commit(ctx); err != nil {
272
 		if err := tx.Commit(ctx); err != nil {
@@ -222,12 +284,7 @@ func (h *Handlers) claimRunnerJob(
222
 		if !errors.Is(err, pgx.ErrNoRows) {
284
 		if !errors.Is(err, pgx.ErrNoRows) {
223
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
285
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
224
 		}
286
 		}
225
-		if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
287
+		if err := heartbeat(actionsdb.WorkflowRunnerStatusIdle); err != nil {
226
-			ID:       runnerID,
227
-			Labels:   labels,
228
-			Capacity: capacity,
229
-			Status:   actionsdb.WorkflowRunnerStatusIdle,
230
-		}); err != nil {
231
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
288
 			return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
232
 		}
289
 		}
233
 		if err := tx.Commit(ctx); err != nil {
290
 		if err := tx.Commit(ctx); err != nil {
@@ -268,18 +325,24 @@ func (h *Handlers) claimRunnerJob(
268
 	if running+1 >= capacity {
325
 	if running+1 >= capacity {
269
 		status = actionsdb.WorkflowRunnerStatusBusy
326
 		status = actionsdb.WorkflowRunnerStatusBusy
270
 	}
327
 	}
271
-	if _, err := q.HeartbeatRunner(ctx, tx, actionsdb.HeartbeatRunnerParams{
328
+	if err := heartbeat(status); err != nil {
272
-		ID:       runnerID,
273
-		Labels:   labels,
274
-		Capacity: capacity,
275
-		Status:   status,
276
-	}); err != nil {
277
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
329
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
278
 	}
330
 	}
331
+	var claimLatencySeconds float64
332
+	observeClaimLatency := false
333
+	if job.CreatedAt.Valid {
334
+		if latency := time.Since(job.CreatedAt.Time); latency >= 0 {
335
+			claimLatencySeconds = latency.Seconds()
336
+			observeClaimLatency = true
337
+		}
338
+	}
279
 	if err := tx.Commit(ctx); err != nil {
339
 	if err := tx.Commit(ctx); err != nil {
280
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
340
 		return actionsdb.ClaimQueuedWorkflowJobRow{}, nil, nil, false, err
281
 	}
341
 	}
282
 	committed = true
342
 	committed = true
343
+	if observeClaimLatency {
344
+		metrics.ActionsJobClaimLatencySeconds.Observe(claimLatencySeconds)
345
+	}
283
 	return job, steps, resolvedSecrets, true, nil
346
 	return job, steps, resolvedSecrets, true, nil
284
 }
347
 }
285
 
348
 
@@ -326,7 +389,14 @@ func (h *Handlers) authenticateRunnerJob(w http.ResponseWriter, r *http.Request)
326
 		writeAPIError(w, http.StatusUnauthorized, "job token invalid")
389
 		writeAPIError(w, http.StatusUnauthorized, "job token invalid")
327
 		return runnerJobAuth{}, false
390
 		return runnerJobAuth{}, false
328
 	}
391
 	}
329
-	job, err := actionsdb.New().GetWorkflowJobByID(r.Context(), h.d.Pool, pathJobID)
392
+	q := actionsdb.New()
393
+	runner, err := q.GetRunnerByID(r.Context(), h.d.Pool, runnerID)
394
+	if err != nil || runner.RevokedAt.Valid {
395
+		metrics.ActionsRunnerJWTTotal.WithLabelValues("rejected").Inc()
396
+		writeAPIError(w, http.StatusUnauthorized, "job token invalid")
397
+		return runnerJobAuth{}, false
398
+	}
399
+	job, err := q.GetWorkflowJobByID(r.Context(), h.d.Pool, pathJobID)
330
 	if err != nil {
400
 	if err != nil {
331
 		if errors.Is(err, pgx.ErrNoRows) {
401
 		if errors.Is(err, pgx.ErrNoRows) {
332
 			writeAPIError(w, http.StatusNotFound, "job not found")
402
 			writeAPIError(w, http.StatusNotFound, "job not found")
internal/web/handlers/api/runners_test.go modified
@@ -59,7 +59,7 @@ func TestRunnerHeartbeatClaimsQueuedJob(t *testing.T) {
59
 	router := newRunnerAPIRouter(t, pool, logger, signer)
59
 	router := newRunnerAPIRouter(t, pool, logger, signer)
60
 
60
 
61
 	req := httptest.NewRequest(http.MethodPost, "/api/v1/runners/heartbeat",
61
 	req := httptest.NewRequest(http.MethodPost, "/api/v1/runners/heartbeat",
62
-		strings.NewReader(`{"labels":["ubuntu-latest","linux"],"capacity":1}`))
62
+		strings.NewReader(`{"labels":["ubuntu-latest","linux"],"capacity":1,"host_name":"runner-host","version":"dev-test"}`))
63
 	req.Header.Set("Authorization", "Bearer "+token)
63
 	req.Header.Set("Authorization", "Bearer "+token)
64
 	rr := httptest.NewRecorder()
64
 	rr := httptest.NewRecorder()
65
 	router.ServeHTTP(rr, req)
65
 	router.ServeHTTP(rr, req)
@@ -122,6 +122,13 @@ func TestRunnerHeartbeatClaimsQueuedJob(t *testing.T) {
122
 		checkoutClaims.Purpose != runnerjwt.PurposeCheckout {
122
 		checkoutClaims.Purpose != runnerjwt.PurposeCheckout {
123
 		t.Fatalf("checkout claims/job mismatch: claims=%+v job=%+v", checkoutClaims, resp.Job)
123
 		t.Fatalf("checkout claims/job mismatch: claims=%+v job=%+v", checkoutClaims, resp.Job)
124
 	}
124
 	}
125
+	runnerRow, err := actionsdb.New().GetRunnerByID(ctx, pool, runnerID)
126
+	if err != nil {
127
+		t.Fatalf("GetRunnerByID: %v", err)
128
+	}
129
+	if runnerRow.HostName != "runner-host" || runnerRow.Version != "dev-test" {
130
+		t.Fatalf("runner metadata: host=%q version=%q", runnerRow.HostName, runnerRow.Version)
131
+	}
125
 
132
 
126
 	var logResp struct {
133
 	var logResp struct {
127
 		Accepted  bool   `json:"accepted"`
134
 		Accepted  bool   `json:"accepted"`
@@ -189,6 +196,101 @@ func TestRunnerHeartbeatClaimsQueuedJob(t *testing.T) {
189
 	}
196
 	}
190
 }
197
 }
191
 
198
 
199
+func TestRunnerHeartbeatDoesNotClaimWhenDraining(t *testing.T) {
200
+	ctx := context.Background()
201
+	pool := dbtest.NewTestDB(t)
202
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
203
+	repoID, userID := setupRunnerAPIRepo(t, pool)
204
+	runID := enqueueRunnerAPIRun(t, pool, logger, repoID, userID)
205
+	token, runnerID := registerRunnerForTest(t, pool, []string{"ubuntu-latest", "linux"}, 1)
206
+	q := actionsdb.New()
207
+	if _, err := q.SetRunnerDraining(ctx, pool, actionsdb.SetRunnerDrainingParams{
208
+		ID:          runnerID,
209
+		DrainReason: "maintenance",
210
+	}); err != nil {
211
+		t.Fatalf("SetRunnerDraining: %v", err)
212
+	}
213
+	router := newRunnerAPIRouter(t, pool, logger, runnerAPISigner(t, time.Now()))
214
+
215
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/runners/heartbeat",
216
+		strings.NewReader(`{"labels":["ubuntu-latest","linux"],"capacity":1,"host_name":"draining-host","version":"dev-test"}`))
217
+	req.Header.Set("Authorization", "Bearer "+token)
218
+	rr := httptest.NewRecorder()
219
+	router.ServeHTTP(rr, req)
220
+
221
+	if rr.Code != http.StatusNoContent {
222
+		t.Fatalf("status: got %d, want 204; body=%s", rr.Code, rr.Body.String())
223
+	}
224
+	jobs, err := q.ListJobsForRun(ctx, pool, runID)
225
+	if err != nil {
226
+		t.Fatalf("ListJobsForRun: %v", err)
227
+	}
228
+	if len(jobs) != 1 || jobs[0].Status != actionsdb.WorkflowJobStatusQueued {
229
+		t.Fatalf("job was claimed while runner drained: %#v", jobs)
230
+	}
231
+	job, err := q.GetWorkflowJobByID(ctx, pool, jobs[0].ID)
232
+	if err != nil {
233
+		t.Fatalf("GetWorkflowJobByID: %v", err)
234
+	}
235
+	if job.RunnerID.Valid {
236
+		t.Fatalf("job was assigned to runner while drained: %+v", job)
237
+	}
238
+	runnerRow, err := q.GetRunnerByID(ctx, pool, runnerID)
239
+	if err != nil {
240
+		t.Fatalf("GetRunnerByID: %v", err)
241
+	}
242
+	if !runnerRow.DrainingAt.Valid || runnerRow.HostName != "draining-host" {
243
+		t.Fatalf("runner drain/metadata not preserved: %+v", runnerRow)
244
+	}
245
+}
246
+
247
+func TestRunnerJobTokenRejectedAfterRunnerRevoked(t *testing.T) {
248
+	ctx := context.Background()
249
+	pool := dbtest.NewTestDB(t)
250
+	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
251
+	repoID, userID := setupRunnerAPIRepo(t, pool)
252
+	enqueueRunnerAPIRun(t, pool, logger, repoID, userID)
253
+	token, runnerID := registerRunnerForTest(t, pool, []string{"ubuntu-latest", "linux"}, 1)
254
+	router := newRunnerAPIRouter(t, pool, logger, runnerAPISigner(t, time.Now()))
255
+
256
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/runners/heartbeat",
257
+		strings.NewReader(`{"labels":["ubuntu-latest","linux"],"capacity":1}`))
258
+	req.Header.Set("Authorization", "Bearer "+token)
259
+	rr := httptest.NewRecorder()
260
+	router.ServeHTTP(rr, req)
261
+	if rr.Code != http.StatusOK {
262
+		t.Fatalf("claim status: got %d, want 200; body=%s", rr.Code, rr.Body.String())
263
+	}
264
+	var claim struct {
265
+		Token string `json:"token"`
266
+		Job   struct {
267
+			ID int64 `json:"id"`
268
+		} `json:"job"`
269
+	}
270
+	if err := json.Unmarshal(rr.Body.Bytes(), &claim); err != nil {
271
+		t.Fatalf("decode claim: %v", err)
272
+	}
273
+	q := actionsdb.New()
274
+	if _, err := q.RevokeRunner(ctx, pool, actionsdb.RevokeRunnerParams{
275
+		ID:            runnerID,
276
+		RevokedReason: "compromised",
277
+	}); err != nil {
278
+		t.Fatalf("RevokeRunner: %v", err)
279
+	}
280
+	if err := q.RevokeAllTokensForRunner(ctx, pool, runnerID); err != nil {
281
+		t.Fatalf("RevokeAllTokensForRunner: %v", err)
282
+	}
283
+
284
+	statusReq := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/v1/jobs/%d/status", claim.Job.ID),
285
+		strings.NewReader(`{"status":"running"}`))
286
+	statusReq.Header.Set("Authorization", "Bearer "+claim.Token)
287
+	statusRR := httptest.NewRecorder()
288
+	router.ServeHTTP(statusRR, statusReq)
289
+	if statusRR.Code != http.StatusUnauthorized {
290
+		t.Fatalf("status: got %d, want 401; body=%s", statusRR.Code, statusRR.Body.String())
291
+	}
292
+}
293
+
192
 func TestRunnerHeartbeatBypassesGlobalAnonAPILimit(t *testing.T) {
294
 func TestRunnerHeartbeatBypassesGlobalAnonAPILimit(t *testing.T) {
193
 	pool := dbtest.NewTestDB(t)
295
 	pool := dbtest.NewTestDB(t)
194
 	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
296
 	logger := slog.New(slog.NewTextHandler(io.Discard, nil))