| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | // Package runnerjwt signs and verifies the short-lived job tokens used by |
| 4 | // shithub Actions runners. |
| 5 | // |
| 6 | // Registration tokens authenticate a runner to the heartbeat endpoint. A |
| 7 | // successful claim receives one JWT scoped to one workflow_jobs row; job |
| 8 | // endpoints verify the signature, expiry, path/job match, and then consume |
| 9 | // the jti through runner_jwt_used so the token is single-use. |
| 10 | package runnerjwt |
| 11 | |
| 12 | import ( |
| 13 | "crypto/hkdf" |
| 14 | "crypto/hmac" |
| 15 | "crypto/rand" |
| 16 | "crypto/sha256" |
| 17 | "encoding/base64" |
| 18 | "encoding/json" |
| 19 | "errors" |
| 20 | "fmt" |
| 21 | "io" |
| 22 | "strconv" |
| 23 | "strings" |
| 24 | "time" |
| 25 | ) |
| 26 | |
// DefaultTTL is the runner job-token lifetime from the S41c contract. Mint
// uses it whenever MintParams.TTL is left at zero.
const DefaultTTL = 15 * time.Minute

// Values accepted in the "purpose" claim.
const (
	PurposeAPI      = "api"
	PurposeCheckout = "checkout"
)

// Internal sizing and key-derivation parameters.
const (
	// signingKeySize is the required byte length of both the configured
	// input key and the derived HS256 key.
	signingKeySize = 32
	// hkdfInfo is the HKDF info string used when deriving the signing key.
	hkdfInfo = "actions-runner-jwt-v1"
	// jtiBytes is the number of random bytes drawn for each jti before
	// base64url encoding.
	jtiBytes = 32
)
| 38 | |
// Errors reported while loading or deriving the signing key.
var (
	// ErrEmptyKey is returned when the configured base64 key string is empty.
	ErrEmptyKey = errors.New("runnerjwt: empty key")
	// ErrInvalidKey is returned when a key is not exactly 32 bytes long.
	ErrInvalidKey = errors.New("runnerjwt: key must be 32 bytes")
)

// Errors reported while minting or verifying tokens.
var (
	// ErrMalformed is returned when a token does not have three non-empty
	// base64url segments carrying decodable JSON.
	ErrMalformed = errors.New("runnerjwt: malformed token")
	// ErrInvalidSignature is returned when the HS256 signature does not match.
	ErrInvalidSignature = errors.New("runnerjwt: invalid signature")
	// ErrExpired is returned when the current time is at or past exp.
	ErrExpired = errors.New("runnerjwt: expired token")
	// ErrInvalidClaims is returned when claim values or mint parameters are
	// out of range.
	ErrInvalidClaims = errors.New("runnerjwt: invalid claims")
	// ErrUnsupportedHeader is returned when the header is not
	// alg=HS256, typ=JWT.
	ErrUnsupportedHeader = errors.New("runnerjwt: unsupported header")
)
| 48 | |
// Claims are the JWT payload fields accepted by runner job endpoints.
//
// The field order is the order in which the payload JSON is marshalled when a
// token is signed.
type Claims struct {
	// Sub is the subject in the form "runner:<id>"; see RunnerID.
	Sub string `json:"sub"`
	// JobID is the job the token is scoped to. It must be positive.
	JobID int64 `json:"job_id"`
	// RunID must be positive.
	// NOTE(review): presumably the workflow run owning the job — confirm.
	RunID int64 `json:"run_id"`
	// RepoID must be positive.
	// NOTE(review): presumably the repository owning the run — confirm.
	RepoID int64 `json:"repo_id"`
	// Exp is the expiry instant in Unix seconds. Verify rejects the token
	// once the clock reaches this value.
	Exp int64 `json:"exp"`
	// JTI is the token identifier; validation requires a length of 16 to 128
	// bytes. Mint fills it with base64url-encoded random bytes.
	JTI string `json:"jti"`
	// Purpose is PurposeAPI or PurposeCheckout. An empty value is also
	// accepted by validation and is omitted from the JSON.
	Purpose string `json:"purpose,omitempty"`
}
| 59 | |
| 60 | // RunnerID extracts the runner id encoded in sub="runner:<id>". |
| 61 | func (c Claims) RunnerID() (int64, error) { |
| 62 | const prefix = "runner:" |
| 63 | if !strings.HasPrefix(c.Sub, prefix) { |
| 64 | return 0, ErrInvalidClaims |
| 65 | } |
| 66 | id, err := strconv.ParseInt(strings.TrimPrefix(c.Sub, prefix), 10, 64) |
| 67 | if err != nil || id <= 0 { |
| 68 | return 0, ErrInvalidClaims |
| 69 | } |
| 70 | return id, nil |
| 71 | } |
| 72 | |
// MintParams describes a job token to issue.
type MintParams struct {
	// RunnerID becomes the "runner:<id>" subject. It must be positive.
	RunnerID int64
	// JobID is copied into the job_id claim. It must be positive.
	JobID int64
	// RunID is copied into the run_id claim. It must be positive.
	RunID int64
	// RepoID is copied into the repo_id claim. It must be positive.
	RepoID int64
	// TTL is the token lifetime. Zero selects DefaultTTL; a negative value
	// is rejected with ErrInvalidClaims.
	TTL time.Duration
	// Purpose is PurposeAPI or PurposeCheckout. Empty selects PurposeAPI.
	Purpose string
}
| 82 | |
// Signer signs and verifies HS256 runner JWTs.
//
// Build one with NewFromKey or NewFromTOTPKeyB64; the zero value has no key,
// clock, or random source.
type Signer struct {
	// key is the HMAC-SHA256 signing key (a private copy made by NewFromKey).
	key []byte
	// now supplies the current time for exp issuance and expiry checks.
	now func() time.Time
	// rng is the random source used to generate jti values.
	rng io.Reader
}
| 89 | |
// Option customizes a Signer. NewFromKey applies options in order after
// installing its defaults. Tests use these for deterministic time/randomness.
type Option func(*Signer)
| 92 | |
| 93 | // WithClock overrides the clock used for exp validation and issuance. |
| 94 | func WithClock(now func() time.Time) Option { |
| 95 | return func(s *Signer) { |
| 96 | if now != nil { |
| 97 | s.now = now |
| 98 | } |
| 99 | } |
| 100 | } |
| 101 | |
| 102 | // WithRand overrides the random source used for jti generation. |
| 103 | func WithRand(r io.Reader) Option { |
| 104 | return func(s *Signer) { |
| 105 | if r != nil { |
| 106 | s.rng = r |
| 107 | } |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | // NewFromTOTPKeyB64 decodes cfg.Auth.TOTPKeyB64 and derives an isolated |
| 112 | // runner-JWT signing key via HKDF. The raw TOTP/secretbox key is never used |
| 113 | // directly for JWT signatures. |
| 114 | func NewFromTOTPKeyB64(totpKeyB64 string, opts ...Option) (*Signer, error) { |
| 115 | key, err := DeriveKeyFromTOTPKeyB64(totpKeyB64) |
| 116 | if err != nil { |
| 117 | return nil, err |
| 118 | } |
| 119 | return NewFromKey(key, opts...) |
| 120 | } |
| 121 | |
| 122 | // DeriveKeyFromTOTPKeyB64 returns the HS256 key derived from the configured |
| 123 | // 32-byte TOTP/secretbox key. |
| 124 | func DeriveKeyFromTOTPKeyB64(totpKeyB64 string) ([]byte, error) { |
| 125 | if totpKeyB64 == "" { |
| 126 | return nil, ErrEmptyKey |
| 127 | } |
| 128 | raw, err := decodeKey(totpKeyB64) |
| 129 | if err != nil { |
| 130 | return nil, fmt.Errorf("runnerjwt: decode key: %w", err) |
| 131 | } |
| 132 | if len(raw) != signingKeySize { |
| 133 | return nil, ErrInvalidKey |
| 134 | } |
| 135 | key, err := hkdf.Key(sha256.New, raw, nil, hkdfInfo, signingKeySize) |
| 136 | if err != nil { |
| 137 | return nil, fmt.Errorf("runnerjwt: derive key: %w", err) |
| 138 | } |
| 139 | return key, nil |
| 140 | } |
| 141 | |
| 142 | // NewFromKey constructs a Signer from an already-derived 32-byte HS256 key. |
| 143 | func NewFromKey(key []byte, opts ...Option) (*Signer, error) { |
| 144 | if len(key) != signingKeySize { |
| 145 | return nil, ErrInvalidKey |
| 146 | } |
| 147 | copied := make([]byte, len(key)) |
| 148 | copy(copied, key) |
| 149 | s := &Signer{ |
| 150 | key: copied, |
| 151 | now: time.Now, |
| 152 | rng: rand.Reader, |
| 153 | } |
| 154 | for _, opt := range opts { |
| 155 | opt(s) |
| 156 | } |
| 157 | return s, nil |
| 158 | } |
| 159 | |
| 160 | // Mint signs a new job token and returns the token plus the exact claims. |
| 161 | func (s *Signer) Mint(p MintParams) (string, Claims, error) { |
| 162 | ttl := p.TTL |
| 163 | if ttl == 0 { |
| 164 | ttl = DefaultTTL |
| 165 | } |
| 166 | if p.RunnerID <= 0 || p.JobID <= 0 || p.RunID <= 0 || p.RepoID <= 0 || ttl <= 0 { |
| 167 | return "", Claims{}, ErrInvalidClaims |
| 168 | } |
| 169 | jti, err := newJTI(s.rng) |
| 170 | if err != nil { |
| 171 | return "", Claims{}, err |
| 172 | } |
| 173 | purpose := p.Purpose |
| 174 | if purpose == "" { |
| 175 | purpose = PurposeAPI |
| 176 | } |
| 177 | claims := Claims{ |
| 178 | Sub: fmt.Sprintf("runner:%d", p.RunnerID), |
| 179 | JobID: p.JobID, |
| 180 | RunID: p.RunID, |
| 181 | RepoID: p.RepoID, |
| 182 | Exp: s.now().Add(ttl).Unix(), |
| 183 | JTI: jti, |
| 184 | Purpose: purpose, |
| 185 | } |
| 186 | if err := validateClaims(claims); err != nil { |
| 187 | return "", Claims{}, err |
| 188 | } |
| 189 | token, err := s.sign(claims) |
| 190 | if err != nil { |
| 191 | return "", Claims{}, err |
| 192 | } |
| 193 | return token, claims, nil |
| 194 | } |
| 195 | |
| 196 | // Verify checks token shape, HS256 signature, registered claims, and expiry. |
| 197 | // It does not consume jti; callers perform that DB operation after verifying |
| 198 | // path/job ownership. |
| 199 | func (s *Signer) Verify(token string) (Claims, error) { |
| 200 | parts := strings.Split(token, ".") |
| 201 | if len(parts) != 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" { |
| 202 | return Claims{}, ErrMalformed |
| 203 | } |
| 204 | headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) |
| 205 | if err != nil { |
| 206 | return Claims{}, ErrMalformed |
| 207 | } |
| 208 | var header struct { |
| 209 | Alg string `json:"alg"` |
| 210 | Typ string `json:"typ"` |
| 211 | } |
| 212 | if err := json.Unmarshal(headerBytes, &header); err != nil { |
| 213 | return Claims{}, ErrMalformed |
| 214 | } |
| 215 | if header.Alg != "HS256" || header.Typ != "JWT" { |
| 216 | return Claims{}, ErrUnsupportedHeader |
| 217 | } |
| 218 | |
| 219 | signingInput := parts[0] + "." + parts[1] |
| 220 | gotSig, err := base64.RawURLEncoding.DecodeString(parts[2]) |
| 221 | if err != nil { |
| 222 | return Claims{}, ErrMalformed |
| 223 | } |
| 224 | wantSig := signHS256(s.key, signingInput) |
| 225 | if !hmac.Equal(gotSig, wantSig) { |
| 226 | return Claims{}, ErrInvalidSignature |
| 227 | } |
| 228 | |
| 229 | payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) |
| 230 | if err != nil { |
| 231 | return Claims{}, ErrMalformed |
| 232 | } |
| 233 | var claims Claims |
| 234 | if err := json.Unmarshal(payloadBytes, &claims); err != nil { |
| 235 | return Claims{}, ErrMalformed |
| 236 | } |
| 237 | if err := validateClaims(claims); err != nil { |
| 238 | return Claims{}, err |
| 239 | } |
| 240 | if !s.now().Before(time.Unix(claims.Exp, 0)) { |
| 241 | return Claims{}, ErrExpired |
| 242 | } |
| 243 | return claims, nil |
| 244 | } |
| 245 | |
| 246 | func (s *Signer) sign(claims Claims) (string, error) { |
| 247 | headerJSON, err := json.Marshal(struct { |
| 248 | Alg string `json:"alg"` |
| 249 | Typ string `json:"typ"` |
| 250 | }{Alg: "HS256", Typ: "JWT"}) |
| 251 | if err != nil { |
| 252 | return "", err |
| 253 | } |
| 254 | payloadJSON, err := json.Marshal(claims) |
| 255 | if err != nil { |
| 256 | return "", err |
| 257 | } |
| 258 | header := base64.RawURLEncoding.EncodeToString(headerJSON) |
| 259 | payload := base64.RawURLEncoding.EncodeToString(payloadJSON) |
| 260 | signingInput := header + "." + payload |
| 261 | sig := base64.RawURLEncoding.EncodeToString(signHS256(s.key, signingInput)) |
| 262 | return signingInput + "." + sig, nil |
| 263 | } |
| 264 | |
// signHS256 returns the raw 32-byte HMAC-SHA256 of signingInput under key.
func signHS256(key []byte, signingInput string) []byte {
	h := hmac.New(sha256.New, key)
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write([]byte(signingInput))
	digest := make([]byte, 0, sha256.Size)
	return h.Sum(digest)
}
| 270 | |
| 271 | func newJTI(r io.Reader) (string, error) { |
| 272 | buf := make([]byte, jtiBytes) |
| 273 | if _, err := io.ReadFull(r, buf); err != nil { |
| 274 | return "", fmt.Errorf("runnerjwt: jti: %w", err) |
| 275 | } |
| 276 | return base64.RawURLEncoding.EncodeToString(buf), nil |
| 277 | } |
| 278 | |
| 279 | func validateClaims(c Claims) error { |
| 280 | if _, err := c.RunnerID(); err != nil { |
| 281 | return err |
| 282 | } |
| 283 | if c.JobID <= 0 || c.RunID <= 0 || c.RepoID <= 0 || c.Exp <= 0 { |
| 284 | return ErrInvalidClaims |
| 285 | } |
| 286 | switch c.Purpose { |
| 287 | case "", PurposeAPI, PurposeCheckout: |
| 288 | default: |
| 289 | return ErrInvalidClaims |
| 290 | } |
| 291 | if len(c.JTI) < 16 || len(c.JTI) > 128 { |
| 292 | return ErrInvalidClaims |
| 293 | } |
| 294 | return nil |
| 295 | } |
| 296 | |
// decodeKey decodes s as base64, trying the standard and URL-safe alphabets,
// each with and without padding, in that order. The first successful decode
// wins; if none succeeds, the error from the last attempt (unpadded URL-safe)
// is returned.
func decodeKey(s string) ([]byte, error) {
	var err error
	for _, codec := range [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		var decoded []byte
		if decoded, err = codec.DecodeString(s); err == nil {
			return decoded, nil
		}
	}
	return nil, err
}
| 314 |