Go · 8108 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package runnerjwt signs and verifies the short-lived job tokens used by
4 // shithub Actions runners.
5 //
6 // Registration tokens authenticate a runner to the heartbeat endpoint. A
7 // successful claim receives one JWT scoped to one workflow_jobs row; job
8 // endpoints verify the signature, expiry, path/job match, and then consume
9 // the jti through runner_jwt_used so the token is single-use.
10 package runnerjwt
11
12 import (
13 "crypto/hkdf"
14 "crypto/hmac"
15 "crypto/rand"
16 "crypto/sha256"
17 "encoding/base64"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "io"
22 "strconv"
23 "strings"
24 "time"
25 )
26
27 const (
28 // DefaultTTL is the runner job-token lifetime from the S41c contract.
29 DefaultTTL = 15 * time.Minute
30
31 PurposeAPI = "api"
32 PurposeCheckout = "checkout"
33
34 signingKeySize = 32
35 hkdfInfo = "actions-runner-jwt-v1"
36 jtiBytes = 32
37 )
38
39 var (
40 ErrEmptyKey = errors.New("runnerjwt: empty key")
41 ErrInvalidKey = errors.New("runnerjwt: key must be 32 bytes")
42 ErrMalformed = errors.New("runnerjwt: malformed token")
43 ErrInvalidSignature = errors.New("runnerjwt: invalid signature")
44 ErrExpired = errors.New("runnerjwt: expired token")
45 ErrInvalidClaims = errors.New("runnerjwt: invalid claims")
46 ErrUnsupportedHeader = errors.New("runnerjwt: unsupported header")
47 )
48
49 // Claims are the JWT payload fields accepted by runner job endpoints.
50 type Claims struct {
51 Sub string `json:"sub"`
52 JobID int64 `json:"job_id"`
53 RunID int64 `json:"run_id"`
54 RepoID int64 `json:"repo_id"`
55 Exp int64 `json:"exp"`
56 JTI string `json:"jti"`
57 Purpose string `json:"purpose,omitempty"`
58 }
59
60 // RunnerID extracts the runner id encoded in sub="runner:<id>".
61 func (c Claims) RunnerID() (int64, error) {
62 const prefix = "runner:"
63 if !strings.HasPrefix(c.Sub, prefix) {
64 return 0, ErrInvalidClaims
65 }
66 id, err := strconv.ParseInt(strings.TrimPrefix(c.Sub, prefix), 10, 64)
67 if err != nil || id <= 0 {
68 return 0, ErrInvalidClaims
69 }
70 return id, nil
71 }
72
73 // MintParams describes a job token to issue.
74 type MintParams struct {
75 RunnerID int64
76 JobID int64
77 RunID int64
78 RepoID int64
79 TTL time.Duration
80 Purpose string
81 }
82
83 // Signer signs and verifies HS256 runner JWTs.
84 type Signer struct {
85 key []byte
86 now func() time.Time
87 rng io.Reader
88 }
89
90 // Option customizes a Signer. Tests use these for deterministic time/randomness.
91 type Option func(*Signer)
92
93 // WithClock overrides the clock used for exp validation and issuance.
94 func WithClock(now func() time.Time) Option {
95 return func(s *Signer) {
96 if now != nil {
97 s.now = now
98 }
99 }
100 }
101
102 // WithRand overrides the random source used for jti generation.
103 func WithRand(r io.Reader) Option {
104 return func(s *Signer) {
105 if r != nil {
106 s.rng = r
107 }
108 }
109 }
110
111 // NewFromTOTPKeyB64 decodes cfg.Auth.TOTPKeyB64 and derives an isolated
112 // runner-JWT signing key via HKDF. The raw TOTP/secretbox key is never used
113 // directly for JWT signatures.
114 func NewFromTOTPKeyB64(totpKeyB64 string, opts ...Option) (*Signer, error) {
115 key, err := DeriveKeyFromTOTPKeyB64(totpKeyB64)
116 if err != nil {
117 return nil, err
118 }
119 return NewFromKey(key, opts...)
120 }
121
122 // DeriveKeyFromTOTPKeyB64 returns the HS256 key derived from the configured
123 // 32-byte TOTP/secretbox key.
124 func DeriveKeyFromTOTPKeyB64(totpKeyB64 string) ([]byte, error) {
125 if totpKeyB64 == "" {
126 return nil, ErrEmptyKey
127 }
128 raw, err := decodeKey(totpKeyB64)
129 if err != nil {
130 return nil, fmt.Errorf("runnerjwt: decode key: %w", err)
131 }
132 if len(raw) != signingKeySize {
133 return nil, ErrInvalidKey
134 }
135 key, err := hkdf.Key(sha256.New, raw, nil, hkdfInfo, signingKeySize)
136 if err != nil {
137 return nil, fmt.Errorf("runnerjwt: derive key: %w", err)
138 }
139 return key, nil
140 }
141
142 // NewFromKey constructs a Signer from an already-derived 32-byte HS256 key.
143 func NewFromKey(key []byte, opts ...Option) (*Signer, error) {
144 if len(key) != signingKeySize {
145 return nil, ErrInvalidKey
146 }
147 copied := make([]byte, len(key))
148 copy(copied, key)
149 s := &Signer{
150 key: copied,
151 now: time.Now,
152 rng: rand.Reader,
153 }
154 for _, opt := range opts {
155 opt(s)
156 }
157 return s, nil
158 }
159
160 // Mint signs a new job token and returns the token plus the exact claims.
161 func (s *Signer) Mint(p MintParams) (string, Claims, error) {
162 ttl := p.TTL
163 if ttl == 0 {
164 ttl = DefaultTTL
165 }
166 if p.RunnerID <= 0 || p.JobID <= 0 || p.RunID <= 0 || p.RepoID <= 0 || ttl <= 0 {
167 return "", Claims{}, ErrInvalidClaims
168 }
169 jti, err := newJTI(s.rng)
170 if err != nil {
171 return "", Claims{}, err
172 }
173 purpose := p.Purpose
174 if purpose == "" {
175 purpose = PurposeAPI
176 }
177 claims := Claims{
178 Sub: fmt.Sprintf("runner:%d", p.RunnerID),
179 JobID: p.JobID,
180 RunID: p.RunID,
181 RepoID: p.RepoID,
182 Exp: s.now().Add(ttl).Unix(),
183 JTI: jti,
184 Purpose: purpose,
185 }
186 if err := validateClaims(claims); err != nil {
187 return "", Claims{}, err
188 }
189 token, err := s.sign(claims)
190 if err != nil {
191 return "", Claims{}, err
192 }
193 return token, claims, nil
194 }
195
196 // Verify checks token shape, HS256 signature, registered claims, and expiry.
197 // It does not consume jti; callers perform that DB operation after verifying
198 // path/job ownership.
199 func (s *Signer) Verify(token string) (Claims, error) {
200 parts := strings.Split(token, ".")
201 if len(parts) != 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" {
202 return Claims{}, ErrMalformed
203 }
204 headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
205 if err != nil {
206 return Claims{}, ErrMalformed
207 }
208 var header struct {
209 Alg string `json:"alg"`
210 Typ string `json:"typ"`
211 }
212 if err := json.Unmarshal(headerBytes, &header); err != nil {
213 return Claims{}, ErrMalformed
214 }
215 if header.Alg != "HS256" || header.Typ != "JWT" {
216 return Claims{}, ErrUnsupportedHeader
217 }
218
219 signingInput := parts[0] + "." + parts[1]
220 gotSig, err := base64.RawURLEncoding.DecodeString(parts[2])
221 if err != nil {
222 return Claims{}, ErrMalformed
223 }
224 wantSig := signHS256(s.key, signingInput)
225 if !hmac.Equal(gotSig, wantSig) {
226 return Claims{}, ErrInvalidSignature
227 }
228
229 payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
230 if err != nil {
231 return Claims{}, ErrMalformed
232 }
233 var claims Claims
234 if err := json.Unmarshal(payloadBytes, &claims); err != nil {
235 return Claims{}, ErrMalformed
236 }
237 if err := validateClaims(claims); err != nil {
238 return Claims{}, err
239 }
240 if !s.now().Before(time.Unix(claims.Exp, 0)) {
241 return Claims{}, ErrExpired
242 }
243 return claims, nil
244 }
245
246 func (s *Signer) sign(claims Claims) (string, error) {
247 headerJSON, err := json.Marshal(struct {
248 Alg string `json:"alg"`
249 Typ string `json:"typ"`
250 }{Alg: "HS256", Typ: "JWT"})
251 if err != nil {
252 return "", err
253 }
254 payloadJSON, err := json.Marshal(claims)
255 if err != nil {
256 return "", err
257 }
258 header := base64.RawURLEncoding.EncodeToString(headerJSON)
259 payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
260 signingInput := header + "." + payload
261 sig := base64.RawURLEncoding.EncodeToString(signHS256(s.key, signingInput))
262 return signingInput + "." + sig, nil
263 }
264
265 func signHS256(key []byte, signingInput string) []byte {
266 mac := hmac.New(sha256.New, key)
267 _, _ = mac.Write([]byte(signingInput))
268 return mac.Sum(nil)
269 }
270
271 func newJTI(r io.Reader) (string, error) {
272 buf := make([]byte, jtiBytes)
273 if _, err := io.ReadFull(r, buf); err != nil {
274 return "", fmt.Errorf("runnerjwt: jti: %w", err)
275 }
276 return base64.RawURLEncoding.EncodeToString(buf), nil
277 }
278
279 func validateClaims(c Claims) error {
280 if _, err := c.RunnerID(); err != nil {
281 return err
282 }
283 if c.JobID <= 0 || c.RunID <= 0 || c.RepoID <= 0 || c.Exp <= 0 {
284 return ErrInvalidClaims
285 }
286 switch c.Purpose {
287 case "", PurposeAPI, PurposeCheckout:
288 default:
289 return ErrInvalidClaims
290 }
291 if len(c.JTI) < 16 || len(c.JTI) > 128 {
292 return ErrInvalidClaims
293 }
294 return nil
295 }
296
297 func decodeKey(s string) ([]byte, error) {
298 encodings := []*base64.Encoding{
299 base64.StdEncoding,
300 base64.RawStdEncoding,
301 base64.URLEncoding,
302 base64.RawURLEncoding,
303 }
304 var lastErr error
305 for _, enc := range encodings {
306 raw, err := enc.DecodeString(s)
307 if err == nil {
308 return raw, nil
309 }
310 lastErr = err
311 }
312 return nil, lastErr
313 }
314