Go · 4930 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package runnerjwt_test
4
5 import (
6 "encoding/base64"
7 "errors"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/tenseleyFlow/shithub/internal/auth/runnerjwt"
13 )
14
15 func TestDeriveKeyFromTOTPKeyB64_UsesHKDFLabel(t *testing.T) {
16 raw := bytesOf(0x42, 32)
17 derived, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(raw))
18 if err != nil {
19 t.Fatalf("DeriveKeyFromTOTPKeyB64: %v", err)
20 }
21 if len(derived) != 32 {
22 t.Fatalf("derived length: got %d, want 32", len(derived))
23 }
24 if string(derived) == string(raw) {
25 t.Fatal("derived key matched raw TOTP key; want HKDF isolation")
26 }
27 }
28
29 func TestMintVerifyRoundTrip(t *testing.T) {
30 now := time.Date(2026, 5, 10, 12, 0, 0, 0, time.UTC)
31 signer := newTestSigner(t, now, bytesOf(0x11, 32))
32
33 token, claims, err := signer.Mint(runnerjwt.MintParams{
34 RunnerID: 7,
35 JobID: 11,
36 RunID: 13,
37 RepoID: 17,
38 })
39 if err != nil {
40 t.Fatalf("Mint: %v", err)
41 }
42 if got := len(strings.Split(token, ".")); got != 3 {
43 t.Fatalf("token parts: got %d, want 3", got)
44 }
45 if claims.Exp != now.Add(runnerjwt.DefaultTTL).Unix() {
46 t.Fatalf("exp: got %d, want %d", claims.Exp, now.Add(runnerjwt.DefaultTTL).Unix())
47 }
48 runnerID, err := claims.RunnerID()
49 if err != nil {
50 t.Fatalf("RunnerID: %v", err)
51 }
52 if runnerID != 7 {
53 t.Fatalf("RunnerID: got %d, want 7", runnerID)
54 }
55
56 got, err := signer.Verify(token)
57 if err != nil {
58 t.Fatalf("Verify: %v", err)
59 }
60 if got != claims {
61 t.Fatalf("claims mismatch:\n got %#v\nwant %#v", got, claims)
62 }
63 }
64
65 func TestVerifyRejectsTamperedPayload(t *testing.T) {
66 signer := newTestSigner(t, time.Unix(100, 0), bytesOf(0x22, 32))
67 token, _, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
68 if err != nil {
69 t.Fatalf("Mint: %v", err)
70 }
71 parts := strings.Split(token, ".")
72 replacement := "A"
73 if parts[1][len(parts[1])-1] == 'A' {
74 replacement = "B"
75 }
76 parts[1] = parts[1][:len(parts[1])-1] + replacement
77
78 if _, err := signer.Verify(strings.Join(parts, ".")); !errors.Is(err, runnerjwt.ErrInvalidSignature) {
79 t.Fatalf("Verify tampered payload: got %v, want ErrInvalidSignature", err)
80 }
81 }
82
83 func TestVerifyRejectsExpiredToken(t *testing.T) {
84 issuedAt := time.Unix(100, 0)
85 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x99, 32)))
86 if err != nil {
87 t.Fatalf("derive: %v", err)
88 }
89 signer, err := runnerjwt.NewFromKey(
90 key,
91 runnerjwt.WithClock(func() time.Time { return issuedAt }),
92 runnerjwt.WithRand(strings.NewReader(string(bytesOf(0x55, 32)))),
93 )
94 if err != nil {
95 t.Fatalf("NewFromKey signer: %v", err)
96 }
97 token, _, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4, TTL: time.Second})
98 if err != nil {
99 t.Fatalf("Mint: %v", err)
100 }
101 verifier, err := runnerjwt.NewFromKey(key, runnerjwt.WithClock(func() time.Time { return issuedAt.Add(time.Second) }))
102 if err != nil {
103 t.Fatalf("NewFromKey verifier: %v", err)
104 }
105 if _, err := verifier.Verify(token); !errors.Is(err, runnerjwt.ErrExpired) {
106 t.Fatalf("Verify expired: got %v, want ErrExpired", err)
107 }
108 }
109
110 func TestVerifyRejectsUnsupportedHeader(t *testing.T) {
111 signer := newTestSigner(t, time.Unix(100, 0), bytesOf(0x44, 32))
112 if _, err := signer.Verify("eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.e30.sig"); !errors.Is(err, runnerjwt.ErrUnsupportedHeader) {
113 t.Fatalf("Verify unsupported header: got %v, want ErrUnsupportedHeader", err)
114 }
115 }
116
117 func TestMintGeneratesDistinctJTI(t *testing.T) {
118 now := time.Unix(100, 0)
119 rng := strings.NewReader(string(append(bytesOf(0x01, 32), bytesOf(0x02, 32)...)))
120 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x77, 32)))
121 if err != nil {
122 t.Fatalf("derive: %v", err)
123 }
124 signer, err := runnerjwt.NewFromKey(key, runnerjwt.WithClock(func() time.Time { return now }), runnerjwt.WithRand(rng))
125 if err != nil {
126 t.Fatalf("NewFromKey: %v", err)
127 }
128 _, first, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
129 if err != nil {
130 t.Fatalf("Mint first: %v", err)
131 }
132 _, second, err := signer.Mint(runnerjwt.MintParams{RunnerID: 1, JobID: 2, RunID: 3, RepoID: 4})
133 if err != nil {
134 t.Fatalf("Mint second: %v", err)
135 }
136 if first.JTI == second.JTI {
137 t.Fatalf("JTI reused: %s", first.JTI)
138 }
139 }
140
141 func newTestSigner(t *testing.T, now time.Time, jtiBytes []byte) *runnerjwt.Signer {
142 t.Helper()
143 key, err := runnerjwt.DeriveKeyFromTOTPKeyB64(base64.StdEncoding.EncodeToString(bytesOf(0x99, 32)))
144 if err != nil {
145 t.Fatalf("derive: %v", err)
146 }
147 signer, err := runnerjwt.NewFromKey(
148 key,
149 runnerjwt.WithClock(func() time.Time { return now }),
150 runnerjwt.WithRand(strings.NewReader(string(jtiBytes))),
151 )
152 if err != nil {
153 t.Fatalf("NewFromKey: %v", err)
154 }
155 return signer
156 }
157
// bytesOf returns a newly allocated slice of length n in which every byte is
// b. It panics if n is negative, as make does.
func bytesOf(b byte, n int) []byte {
	buf := make([]byte, n)
	for idx := 0; idx < len(buf); idx++ {
		buf[idx] = b
	}
	return buf
}
165