| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package config |
| 4 | |
| 5 | import ( |
| 6 | "os" |
| 7 | "path/filepath" |
| 8 | "reflect" |
| 9 | "strings" |
| 10 | "testing" |
| 11 | "time" |
| 12 | ) |
| 13 | |
| 14 | func TestLoad_DefaultsWithToken(t *testing.T) { |
| 15 | t.Parallel() |
| 16 | cfg, err := Load(LoadOptions{ |
| 17 | Environ: []string{"SHITHUB_RUNNER_TOKEN=tok"}, |
| 18 | }) |
| 19 | if err != nil { |
| 20 | t.Fatalf("Load: %v", err) |
| 21 | } |
| 22 | if cfg.Server.BaseURL != "http://127.0.0.1:8080" { |
| 23 | t.Fatalf("BaseURL: %q", cfg.Server.BaseURL) |
| 24 | } |
| 25 | if cfg.Engine.Kind != "docker" { |
| 26 | t.Fatalf("Engine.Kind: %q", cfg.Engine.Kind) |
| 27 | } |
| 28 | if cfg.Engine.Network != "shithub-actions" { |
| 29 | t.Fatalf("Engine.Network: %q", cfg.Engine.Network) |
| 30 | } |
| 31 | if cfg.Engine.SeccompProfile != "/etc/shithubd-runner/seccomp.json" { |
| 32 | t.Fatalf("Engine.SeccompProfile: %q", cfg.Engine.SeccompProfile) |
| 33 | } |
| 34 | if want := []string{"172.30.0.1"}; !reflect.DeepEqual(cfg.Engine.DNSServers, want) { |
| 35 | t.Fatalf("DNSServers: got %#v want %#v", cfg.Engine.DNSServers, want) |
| 36 | } |
| 37 | if cfg.Engine.User != "65534:65534" { |
| 38 | t.Fatalf("Engine.User: %q", cfg.Engine.User) |
| 39 | } |
| 40 | if cfg.Engine.PidsLimit != 512 { |
| 41 | t.Fatalf("Engine.PidsLimit: %d", cfg.Engine.PidsLimit) |
| 42 | } |
| 43 | if want := []string{"api.github.com", "auth.docker.io", "codeload.github.com", "github.com", "objects.githubusercontent.com", "production.cloudflare.docker.com", "registry-1.docker.io", "*.githubusercontent.com"}; !reflect.DeepEqual(cfg.Runner.NetworkAllowlist, want) { |
| 44 | t.Fatalf("NetworkAllowlist: got %#v want %#v", cfg.Runner.NetworkAllowlist, want) |
| 45 | } |
| 46 | if cfg.Runner.PollInterval != 5*time.Second { |
| 47 | t.Fatalf("PollInterval: %v", cfg.Runner.PollInterval) |
| 48 | } |
| 49 | } |
| 50 | |
| 51 | func TestLoad_FileEnvAliasAndFlagsPrecedence(t *testing.T) { |
| 52 | t.Parallel() |
| 53 | dir := t.TempDir() |
| 54 | path := filepath.Join(dir, "config.toml") |
| 55 | body := ` |
| 56 | [server] |
| 57 | base_url = "https://file.example/" |
| 58 | |
| 59 | [runner] |
| 60 | token = "file-token" |
| 61 | labels = ["self-hosted", "linux"] |
| 62 | capacity = 2 |
| 63 | poll_interval = "10s" |
| 64 | workspace_root = "/tmp/file" |
| 65 | workspace_ttl = "12h" |
| 66 | network_allowlist = ["github.com", "*.githubusercontent.com"] |
| 67 | |
| 68 | [engine] |
| 69 | kind = "docker" |
| 70 | default_image = "file-image" |
| 71 | network = "none" |
| 72 | memory = "1g" |
| 73 | cpus = "1" |
| 74 | seccomp_profile = "/file/seccomp.json" |
| 75 | user = "1000:1000" |
| 76 | pids_limit = 64 |
| 77 | dns_servers = ["172.30.0.10"] |
| 78 | ` |
| 79 | if err := os.WriteFile(path, []byte(body), 0o600); err != nil { |
| 80 | t.Fatalf("WriteFile: %v", err) |
| 81 | } |
| 82 | |
| 83 | cfg, err := Load(LoadOptions{ |
| 84 | ConfigPath: path, |
| 85 | Environ: []string{ |
| 86 | "SHITHUB_RUNNER_TOKEN=alias-token", |
| 87 | "SHITHUB_RUNNER_ENGINE__PIDS_LIMIT=256", |
| 88 | "SHITHUB_RUNNER_ENGINE__DNS_SERVERS=172.30.0.11,172.30.0.12", |
| 89 | "SHITHUB_RUNNER_RUNNER__CAPACITY=3", |
| 90 | "SHITHUB_RUNNER_RUNNER__LABELS=self-hosted,linux,x64", |
| 91 | }, |
| 92 | Overrides: map[string]string{ |
| 93 | "server.base_url": "https://flag.example/path/", |
| 94 | "runner.capacity": "4", |
| 95 | "runner.poll_interval": "2s", |
| 96 | "runner.network_allowlist": "api.github.com,github.com", |
| 97 | "engine.seccomp_profile": "/flag/seccomp.json", |
| 98 | "engine.user": "123:456", |
| 99 | }, |
| 100 | }) |
| 101 | if err != nil { |
| 102 | t.Fatalf("Load: %v", err) |
| 103 | } |
| 104 | if cfg.Server.BaseURL != "https://flag.example/path" { |
| 105 | t.Fatalf("BaseURL: %q", cfg.Server.BaseURL) |
| 106 | } |
| 107 | if cfg.Runner.Token != "alias-token" { |
| 108 | t.Fatalf("Token: %q", cfg.Runner.Token) |
| 109 | } |
| 110 | if cfg.Runner.Capacity != 4 { |
| 111 | t.Fatalf("Capacity: %d", cfg.Runner.Capacity) |
| 112 | } |
| 113 | if cfg.Runner.PollInterval != 2*time.Second { |
| 114 | t.Fatalf("PollInterval: %v", cfg.Runner.PollInterval) |
| 115 | } |
| 116 | if want := []string{"self-hosted", "linux", "x64"}; !reflect.DeepEqual(cfg.Runner.Labels, want) { |
| 117 | t.Fatalf("Labels: got %#v want %#v", cfg.Runner.Labels, want) |
| 118 | } |
| 119 | if cfg.Engine.SeccompProfile != "/flag/seccomp.json" { |
| 120 | t.Fatalf("SeccompProfile: %q", cfg.Engine.SeccompProfile) |
| 121 | } |
| 122 | if cfg.Engine.User != "123:456" { |
| 123 | t.Fatalf("User: %q", cfg.Engine.User) |
| 124 | } |
| 125 | if cfg.Engine.PidsLimit != 256 { |
| 126 | t.Fatalf("PidsLimit: %d", cfg.Engine.PidsLimit) |
| 127 | } |
| 128 | if want := []string{"api.github.com", "github.com"}; !reflect.DeepEqual(cfg.Runner.NetworkAllowlist, want) { |
| 129 | t.Fatalf("NetworkAllowlist: got %#v want %#v", cfg.Runner.NetworkAllowlist, want) |
| 130 | } |
| 131 | if want := []string{"172.30.0.11", "172.30.0.12"}; !reflect.DeepEqual(cfg.Engine.DNSServers, want) { |
| 132 | t.Fatalf("DNSServers: got %#v want %#v", cfg.Engine.DNSServers, want) |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | func TestLoad_RequiresToken(t *testing.T) { |
| 137 | t.Parallel() |
| 138 | _, err := Load(LoadOptions{Environ: []string{}}) |
| 139 | if err == nil || !strings.Contains(err.Error(), "runner.token") { |
| 140 | t.Fatalf("Load error: %v", err) |
| 141 | } |
| 142 | } |
| 143 | |
| 144 | func TestValidate_RejectsBadCapacity(t *testing.T) { |
| 145 | t.Parallel() |
| 146 | cfg := Defaults() |
| 147 | cfg.Runner.Token = "tok" |
| 148 | cfg.Runner.Capacity = 0 |
| 149 | if err := Validate(&cfg); err == nil { |
| 150 | t.Fatal("Validate returned nil error") |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | func TestValidate_RejectsBadEngineKind(t *testing.T) { |
| 155 | t.Parallel() |
| 156 | cfg := Defaults() |
| 157 | cfg.Runner.Token = "tok" |
| 158 | cfg.Engine.Kind = "runc" |
| 159 | if err := Validate(&cfg); err == nil { |
| 160 | t.Fatal("Validate returned nil error") |
| 161 | } |
| 162 | } |
| 163 | |
| 164 | func TestValidate_RejectsBadPidsLimit(t *testing.T) { |
| 165 | t.Parallel() |
| 166 | cfg := Defaults() |
| 167 | cfg.Runner.Token = "tok" |
| 168 | cfg.Engine.PidsLimit = 0 |
| 169 | if err := Validate(&cfg); err == nil { |
| 170 | t.Fatal("Validate returned nil error") |
| 171 | } |
| 172 | } |
| 173 | |
| 174 | func TestValidate_RejectsBadNetworkAllowlist(t *testing.T) { |
| 175 | t.Parallel() |
| 176 | cfg := Defaults() |
| 177 | cfg.Runner.Token = "tok" |
| 178 | cfg.Runner.NetworkAllowlist = []string{"https://github.com"} |
| 179 | if err := Validate(&cfg); err == nil { |
| 180 | t.Fatal("Validate returned nil error") |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | func TestValidate_RejectsBadDNSServer(t *testing.T) { |
| 185 | t.Parallel() |
| 186 | cfg := Defaults() |
| 187 | cfg.Runner.Token = "tok" |
| 188 | cfg.Engine.DNSServers = []string{"dns.internal"} |
| 189 | if err := Validate(&cfg); err == nil { |
| 190 | t.Fatal("Validate returned nil error") |
| 191 | } |
| 192 | } |
| 193 |