Go · 5742 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package runnerjwt_test
4
5 import (
6 "encoding/base64"
7 "errors"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/tenseleyFlow/shithub/internal/auth/runnerjwt"
13 )
14
15 func TestDeriveKeyFromTOTPKeyB64_UsesHKDFLabel(t *testing.T) {
16 raw := bytesOf(0x42, 32)
17 derived, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(raw))
18 if err != nil {
19 t.Fatalf("DeriveKeyFromTOTPKeyB64: %v", err)
20 }
21 if len(derived) != 32 {
22 t.Fatalf("derived length: got %d, want 32", len(derived))
23 }
24 if string(derived) == string(raw) {
25 t.Fatal("derived key matched raw TOTP key; want HKDF isolation")
26 }
27 }
28
29 func TestMintVerifyRoundTrip(t *testing.T) {
30 now := time.Date(2026, 5, 10, 12, 0, 0, 0, time.UTC)
31 signer := newTestSigner(t, now, bytesOf(0x11, 32))
32
33 token, claims, err := signer.Mint(runnerjwt.MintParams{
34 RunnerID: 7,
35 JobID: 11,
36 RunID: 13,
37 RepoID: 17,
38 })
39 if err != nil {
40 t.Fatalf("Mint: %v", err)
41 }
42 if got := len(strings.Split(token, ".")); got != 3 {
43 t.Fatalf("token parts: got %d, want 3", got)
44 }
45 if claims.Exp != now.Add(runnerjwt.DefaultTTL).Unix() {
46 t.Fatalf("exp: got %d, want %d", claims.Exp, now.Add(runnerjwt.DefaultTTL).Unix())
47 }
48 if claims.Purpose != runnerjwt.PurposeAPI {
49 t.Fatalf("purpose: got %q, want %q", claims.Purpose, runnerjwt.PurposeAPI)
50 }
51 runnerID, err := claims.RunnerID()
52 if err != nil {
53 t.Fatalf("RunnerID: %v", err)
54 }
55 if runnerID != 7 {
56 t.Fatalf("RunnerID: got %d, want 7", runnerID)
57 }
58
59 got, err := signer.Verify(token)
60 if err != nil {
61 t.Fatalf("Verify: %v", err)
62 }
63 if got != claims {
64 t.Fatalf("claims mismatch:\n got %#v\nwant %#v", got, claims)
65 }
66 }
67
68 func TestMintVerifyCheckoutPurpose(t *testing.T) {
69 now := time.Date(2026, 5, 10, 12, 0, 0, 0, time.UTC)
70 signer := newTestSigner(t, now, bytesOf(0x66, 32))
71
72 token, claims, err := signer.Mint(runnerjwt.MintParams{
73 RunnerID: 7,
74 JobID: 11,
75 RunID: 13,
76 RepoID: 17,
77 Purpose: runnerjwt.PurposeCheckout,
78 })
79 if err != nil {
80 t.Fatalf("Mint: %v", err)
81 }
82 if claims.Purpose != runnerjwt.PurposeCheckout {
83 t.Fatalf("purpose: got %q, want checkout", claims.Purpose)
84 }
85 got, err := signer.Verify(token)
86 if err != nil {
87 t.Fatalf("Verify: %v", err)
88 }
89 if got.Purpose != runnerjwt.PurposeCheckout {
90 t.Fatalf("verified purpose: got %q, want checkout", got.Purpose)
91 }
92 }
93
94 func TestVerifyRejectsTamperedPayload(t *testing.T) {
95 signer := newTestSigner(t, time.Unix(100, 0), bytesOf(0x22, 32))
96 token, _, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
97 if err != nil {
98 t.Fatalf("Mint: %v", err)
99 }
100 parts := strings.Split(token, ".")
101 replacement := "A"
102 if parts[1][len(parts[1])-1] == 'A' {
103 replacement = "B"
104 }
105 parts[1] = parts[1][:len(parts[1])-1] + replacement
106
107 if _, err := signer.Verify(strings.Join(parts, ".")); !errors.Is(err, runnerjwt.ErrInvalidSignature) {
108 t.Fatalf("Verify tampered payload: got %v, want ErrInvalidSignature", err)
109 }
110 }
111
112 func TestVerifyRejectsExpiredToken(t *testing.T) {
113 issuedAt := time.Unix(100, 0)
114 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x99, 32)))
115 if err != nil {
116 t.Fatalf("derive: %v", err)
117 }
118 signer, err := runnerjwt.NewFromKey(
119 key,
120 runnerjwt.WithClock(func() time.Time { return issuedAt }),
121 runnerjwt.WithRand(strings.NewReader(string(bytesOf(0x55, 32)))),
122 )
123 if err != nil {
124 t.Fatalf("NewFromKey signer: %v", err)
125 }
126 token, _, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4, TTL: time.Second})
127 if err != nil {
128 t.Fatalf("Mint: %v", err)
129 }
130 verifier, err := runnerjwt.NewFromKey(key, runnerjwt.WithClock(func() time.Time { return issuedAt.Add(time.Second) }))
131 if err != nil {
132 t.Fatalf("NewFromKey verifier: %v", err)
133 }
134 if _, err := verifier.Verify(token); !errors.Is(err, runnerjwt.ErrExpired) {
135 t.Fatalf("Verify expired: got %v, want ErrExpired", err)
136 }
137 }
138
139 func TestVerifyRejectsUnsupportedHeader(t *testing.T) {
140 signer := newTestSigner(t, time.Unix(100, 0), bytesOf(0x44, 32))
141 if _, err := signer.Verify("eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.e30.sig"); !errors.Is(err, runnerjwt.ErrUnsupportedHeader) {
142 t.Fatalf("Verify unsupported header: got %v, want ErrUnsupportedHeader", err)
143 }
144 }
145
146 func TestMintGeneratesDistinctJTI(t *testing.T) {
147 now := time.Unix(100, 0)
148 rng := strings.NewReader(string(append(bytesOf(0x01, 32), bytesOf(0x02, 32)...)))
149 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x77, 32)))
150 if err != nil {
151 t.Fatalf("derive: %v", err)
152 }
153 signer, err := runnerjwt.NewFromKey(key, runnerjwt.WithClock(func() time.Time { return now }), runnerjwt.WithRand(rng))
154 if err != nil {
155 t.Fatalf("NewFromKey: %v", err)
156 }
157 _, first, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
158 if err != nil {
159 t.Fatalf("Mint first: %v", err)
160 }
161 _, second, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
162 if err != nil {
163 t.Fatalf("Mint second: %v", err)
164 }
165 if first.JTI == second.JTI {
166 t.Fatalf("JTI reused: %s", first.JTI)
167 }
168 }
169
170 func newTestSigner(t *testing.T, now time.Time, jtiBytes []byte) *runnerjwt.Signer {
171 t.Helper()
172 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x99, 32)))
173 if err != nil {
174 t.Fatalf("derive: %v", err)
175 }
176 signer, err := runnerjwt.NewFromKey(
177 key,
178 runnerjwt.WithClock(func() time.Time { return now }),
179 runnerjwt.WithRand(strings.NewReader(string(jtiBytes))),
180 )
181 if err != nil {
182 t.Fatalf("NewFromKey: %v", err)
183 }
184 return signer
185 }
186
// bytesOf returns a newly allocated slice of length n with every byte set to b.
// It panics if n is negative. The tests use it for deterministic key material
// and random-source contents.
func bytesOf(b byte, n int) []byte {
	buf := make([]byte, n)
	if n == 0 {
		return buf
	}
	// Seed one byte, then double the filled prefix with copy until it covers buf.
	buf[0] = b
	for filled := 1; filled < n; filled *= 2 {
		copy(buf[filled:], buf[:filled])
	}
	return buf
}
194