Go · 12727 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package config owns the layered configuration loader for shithubd-runner.
4 //
5 // Precedence, lowest to highest:
6 // 1. built-in defaults
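// 2. TOML file (--config, else SHITHUB_RUNNER_CONFIG, else /etc/shithubd-runner/config.toml;
//    only a missing file at that default path is tolerated)
// 3. environment variables with the SHITHUB_RUNNER_ prefix, keyed SECTION__KEY
//    (e.g. SHITHUB_RUNNER_RUNNER__CAPACITY), plus the URL, TOKEN and LABELS shorthands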
9 // 4. CLI flag overrides handed in by the caller
10 package config
11
12 import (
13 "errors"
14 "fmt"
15 "net/netip"
16 "net/url"
17 "os"
18 "reflect"
19 "strconv"
20 "strings"
21 "time"
22
23 "github.com/BurntSushi/toml"
24
25 "github.com/tenseleyFlow/shithub/internal/actions/runnerlabels"
26 )
27
// DefaultPath is the config file consulted when neither --config nor
// SHITHUB_RUNNER_CONFIG names one. A missing file at this path is tolerated;
// any other missing path is an error.
const DefaultPath = "/etc/shithubd-runner/config.toml"

// EnvPrefix prefixes every environment variable the loader reads, for example
// SHITHUB_RUNNER_RUNNER__TOKEN or the SHITHUB_RUNNER_URL shorthand.
const EnvPrefix = "SHITHUB_RUNNER_"

// Built-in engine settings used by Defaults.
const (
	defaultImage           = "ghcr.io/shithub/runner-nix:1.0"
	defaultNetwork         = "shithub-actions"
	defaultDNSServer       = "172.30.0.1"
	defaultSeccompProfile  = "/etc/shithubd-runner/seccomp.json"
	defaultContainerUser   = "65534:65534"
	defaultContainerPIDMax = 512
)
38
// defaultNetworkAllowlist is the built-in runner.network_allowlist. Defaults
// hands out a copy, so a returned Config can never mutate this package-level
// slice. Entries are exact hostnames or patterns starting with a "*." label;
// Validate keeps their order. NOTE(review): presumably these are the hosts
// needed to fetch GitHub-hosted sources and Docker Hub images — confirm.
var defaultNetworkAllowlist = []string{
	"api.github.com",
	"auth.docker.io",
	"codeload.github.com",
	"github.com",
	"objects.githubusercontent.com",
	"production.cloudflare.docker.com",
	"registry-1.docker.io",
	"*.githubusercontent.com",
}
49
// LoadOptions controls config resolution. The zero value reads the file named
// by SHITHUB_RUNNER_CONFIG (or DefaultPath when that is unset), uses the
// process environment, and applies no CLI overrides.
type LoadOptions struct {
	// ConfigPath is an explicit TOML file path (the --config flag). When it is
	// non-blank it wins over SHITHUB_RUNNER_CONFIG and DefaultPath; surrounding
	// whitespace is ignored.
	ConfigPath string
	// Overrides holds CLI values keyed by dotted config path, e.g.
	// "runner.capacity" or "engine.kind". Keys are matched case-insensitively,
	// blank keys are ignored, and values are parsed exactly like environment
	// variables.
	Overrides map[string]string
	// Environ is the environment as KEY=VALUE entries. nil means os.Environ();
	// a non-nil empty slice means "no environment at all".
	Environ []string
}
57
// Config is the typed root consumed by cmd/shithubd-runner. Each field is one
// TOML table. The same settings are addressable from the environment as
// SHITHUB_RUNNER_<TABLE>__<KEY> and from LoadOptions.Overrides as
// "<table>.<key>".
type Config struct {
	Server ServerConfig `toml:"server"`
	Runner RunnerConfig `toml:"runner"`
	Engine EngineConfig `toml:"engine"`
	Log    LogConfig    `toml:"log"`
}
65
// ServerConfig is the [server] table.
type ServerConfig struct {
	// BaseURL must be an absolute http or https URL. Validate trims surrounding
	// whitespace and trailing slashes. Shorthand env var: SHITHUB_RUNNER_URL
	// (ignored when SHITHUB_RUNNER_SERVER__BASE_URL is set).
	BaseURL string `toml:"base_url"`
}
69
// RunnerConfig is the [runner] table.
type RunnerConfig struct {
	// Token is required. Shorthand env var: SHITHUB_RUNNER_TOKEN (ignored when
	// SHITHUB_RUNNER_RUNNER__TOKEN is set).
	Token string `toml:"token"`
	// Labels are normalized through the runnerlabels package; a single entry
	// containing commas is split as a list. Shorthand env var:
	// SHITHUB_RUNNER_LABELS (ignored when SHITHUB_RUNNER_RUNNER__LABELS is set).
	Labels []string `toml:"labels"`
	// Capacity must be between 1 and 64. NOTE(review): presumably the number
	// of jobs run concurrently — confirm against the scheduler.
	Capacity int `toml:"capacity"`
	// PollInterval must be positive; env/flag values use time.ParseDuration
	// syntax such as "5s".
	PollInterval time.Duration `toml:"poll_interval"`
	// WorkspaceRoot is a required directory path.
	WorkspaceRoot string `toml:"workspace_root"`
	// WorkspaceTTL must be positive.
	WorkspaceTTL time.Duration `toml:"workspace_ttl"`
	// NetworkAllowlist holds host patterns: exact hostnames or patterns that
	// start with a "*." label. Entries are lower-cased and de-duplicated, and
	// at least one is required.
	NetworkAllowlist []string `toml:"network_allowlist"`
}
79
// EngineConfig is the [engine] table describing the container engine the
// runner drives.
type EngineConfig struct {
	// Kind is "docker" or "podman" (case-insensitive; lower-cased by Validate).
	Kind string `toml:"kind"`
	// DefaultImage is required.
	DefaultImage string `toml:"default_image"`
	// Network is required.
	Network string `toml:"network"`
	// Memory and CPUs are required, non-blank strings that are not parsed
	// here. NOTE(review): presumably forwarded to the engine's resource-limit
	// flags — confirm at the call site.
	Memory string `toml:"memory"`
	CPUs   string `toml:"cpus"`
	// SeccompProfile is a required path.
	SeccompProfile string `toml:"seccomp_profile"`
	// User is required. NOTE(review): the default "65534:65534" suggests a
	// uid:gid pair — confirm.
	User string `toml:"user"`
	// PidsLimit must be positive.
	PidsLimit int `toml:"pids_limit"`
	// DNSServers are IP address literals, de-duplicated by Validate; an empty
	// list is allowed.
	DNSServers []string `toml:"dns_servers"`
}
91
// LogConfig is the [log] table. Both values are case-insensitive and are
// lower-cased by Validate.
type LogConfig struct {
	// Level is one of debug, info, warn or error.
	Level string `toml:"level"`
	// Format is text or json.
	Format string `toml:"format"`
}
96
97 func Defaults() Config {
98 return Config{
99 Server: ServerConfig{
100 BaseURL: "http://127.0.0.1:8080",
101 },
102 Runner: RunnerConfig{
103 Labels: runnerlabels.DefaultShared(),
104 Capacity: 1,
105 PollInterval: 5 * time.Second,
106 WorkspaceRoot: "/var/lib/shithubd-runner/workspaces",
107 WorkspaceTTL: 24 * time.Hour,
108 NetworkAllowlist: append([]string{}, defaultNetworkAllowlist...),
109 },
110 Engine: EngineConfig{
111 Kind: "docker",
112 DefaultImage: defaultImage,
113 Network: defaultNetwork,
114 Memory: "2g",
115 CPUs: "2",
116 SeccompProfile: defaultSeccompProfile,
117 User: defaultContainerUser,
118 PidsLimit: defaultContainerPIDMax,
119 DNSServers: []string{defaultDNSServer},
120 },
121 Log: LogConfig{
122 Level: "info",
123 Format: "text",
124 },
125 }
126 }
127
128 func Load(opts LoadOptions) (Config, error) {
129 cfg := Defaults()
130 environ := opts.Environ
131 if environ == nil {
132 environ = os.Environ()
133 }
134 if err := mergeFile(&cfg, configPath(opts.ConfigPath, environ)); err != nil {
135 return cfg, err
136 }
137 if err := mergeEnv(&cfg, environ); err != nil {
138 return cfg, err
139 }
140 applyAliases(&cfg, environ)
141 if err := mergeFlags(&cfg, opts.Overrides); err != nil {
142 return cfg, err
143 }
144 if err := Validate(&cfg); err != nil {
145 return cfg, err
146 }
147 return cfg, nil
148 }
149
150 func configPath(flagPath string, environ []string) string {
151 if strings.TrimSpace(flagPath) != "" {
152 return strings.TrimSpace(flagPath)
153 }
154 if v := envLookup(environ, EnvPrefix+"CONFIG"); v != "" {
155 return v
156 }
157 return DefaultPath
158 }
159
160 func mergeFile(cfg *Config, path string) error {
161 body, err := os.ReadFile(path) //nolint:gosec // operator-supplied config path
162 if err != nil {
163 if errors.Is(err, os.ErrNotExist) && path == DefaultPath {
164 return nil
165 }
166 return fmt.Errorf("runner config: read %s: %w", path, err)
167 }
168 if _, err := toml.Decode(string(body), cfg); err != nil {
169 return fmt.Errorf("runner config: parse %s: %w", path, err)
170 }
171 return nil
172 }
173
174 func mergeEnv(cfg *Config, environ []string) error {
175 src := make(map[string]string)
176 for _, kv := range environ {
177 if !strings.HasPrefix(kv, EnvPrefix) {
178 continue
179 }
180 eq := strings.IndexByte(kv, '=')
181 if eq < 0 {
182 continue
183 }
184 key := strings.TrimPrefix(kv[:eq], EnvPrefix)
185 src[key] = kv[eq+1:]
186 }
187 return walkAndApply(reflect.ValueOf(cfg).Elem(), reflect.TypeOf(*cfg), "", src)
188 }
189
190 func applyAliases(cfg *Config, environ []string) {
191 if v := envLookup(environ, EnvPrefix+"URL"); v != "" && envLookup(environ, EnvPrefix+"SERVER__BASE_URL") == "" {
192 cfg.Server.BaseURL = v
193 }
194 if v := envLookup(environ, EnvPrefix+"TOKEN"); v != "" && envLookup(environ, EnvPrefix+"RUNNER__TOKEN") == "" {
195 cfg.Runner.Token = v
196 }
197 if v := envLookup(environ, EnvPrefix+"LABELS"); v != "" && envLookup(environ, EnvPrefix+"RUNNER__LABELS") == "" {
198 cfg.Runner.Labels = []string{v}
199 }
200 }
201
202 func mergeFlags(cfg *Config, overrides map[string]string) error {
203 if len(overrides) == 0 {
204 return nil
205 }
206 src := make(map[string]string, len(overrides))
207 for k, v := range overrides {
208 k = strings.TrimSpace(k)
209 if k == "" {
210 continue
211 }
212 src[strings.ToUpper(strings.ReplaceAll(k, ".", "__"))] = v
213 }
214 return walkAndApply(reflect.ValueOf(cfg).Elem(), reflect.TypeOf(*cfg), "", src)
215 }
216
217 func Validate(c *Config) error {
218 c.Server.BaseURL = strings.TrimSpace(c.Server.BaseURL)
219 if c.Server.BaseURL == "" {
220 return errors.New("runner config: server.base_url is required")
221 }
222 u, err := url.Parse(c.Server.BaseURL)
223 if err != nil || u.Scheme == "" || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
224 return fmt.Errorf("runner config: server.base_url must be an absolute http(s) URL, got %q", c.Server.BaseURL)
225 }
226 c.Server.BaseURL = strings.TrimRight(c.Server.BaseURL, "/")
227
228 if strings.TrimSpace(c.Runner.Token) == "" {
229 return errors.New("runner config: runner.token is required")
230 }
231 labels, err := normalizeLabels(c.Runner.Labels)
232 if err != nil {
233 return fmt.Errorf("runner config: runner.labels: %w", err)
234 }
235 c.Runner.Labels = labels
236 if c.Runner.Capacity < 1 || c.Runner.Capacity > 64 {
237 return fmt.Errorf("runner config: runner.capacity must be between 1 and 64, got %d", c.Runner.Capacity)
238 }
239 if c.Runner.PollInterval <= 0 {
240 return errors.New("runner config: runner.poll_interval must be positive")
241 }
242 if strings.TrimSpace(c.Runner.WorkspaceRoot) == "" {
243 return errors.New("runner config: runner.workspace_root is required")
244 }
245 if c.Runner.WorkspaceTTL <= 0 {
246 return errors.New("runner config: runner.workspace_ttl must be positive")
247 }
248 allowlist, err := normalizeHostPatterns(c.Runner.NetworkAllowlist)
249 if err != nil {
250 return fmt.Errorf("runner config: runner.network_allowlist: %w", err)
251 }
252 c.Runner.NetworkAllowlist = allowlist
253
254 switch strings.ToLower(strings.TrimSpace(c.Engine.Kind)) {
255 case "docker", "podman":
256 c.Engine.Kind = strings.ToLower(strings.TrimSpace(c.Engine.Kind))
257 default:
258 return fmt.Errorf("runner config: engine.kind must be docker|podman, got %q", c.Engine.Kind)
259 }
260 if strings.TrimSpace(c.Engine.DefaultImage) == "" {
261 return errors.New("runner config: engine.default_image is required")
262 }
263 if strings.TrimSpace(c.Engine.Network) == "" {
264 return errors.New("runner config: engine.network is required")
265 }
266 if strings.TrimSpace(c.Engine.Memory) == "" {
267 return errors.New("runner config: engine.memory is required")
268 }
269 if strings.TrimSpace(c.Engine.CPUs) == "" {
270 return errors.New("runner config: engine.cpus is required")
271 }
272 c.Engine.SeccompProfile = strings.TrimSpace(c.Engine.SeccompProfile)
273 if c.Engine.SeccompProfile == "" {
274 return errors.New("runner config: engine.seccomp_profile is required")
275 }
276 c.Engine.User = strings.TrimSpace(c.Engine.User)
277 if c.Engine.User == "" {
278 return errors.New("runner config: engine.user is required")
279 }
280 if c.Engine.PidsLimit <= 0 {
281 return fmt.Errorf("runner config: engine.pids_limit must be positive, got %d", c.Engine.PidsLimit)
282 }
283 dnsServers, err := normalizeDNSServers(c.Engine.DNSServers)
284 if err != nil {
285 return fmt.Errorf("runner config: engine.dns_servers: %w", err)
286 }
287 c.Engine.DNSServers = dnsServers
288
289 switch strings.ToLower(c.Log.Level) {
290 case "debug", "info", "warn", "error":
291 c.Log.Level = strings.ToLower(c.Log.Level)
292 default:
293 return fmt.Errorf("runner config: log.level must be debug|info|warn|error, got %q", c.Log.Level)
294 }
295 switch strings.ToLower(c.Log.Format) {
296 case "text", "json":
297 c.Log.Format = strings.ToLower(c.Log.Format)
298 default:
299 return fmt.Errorf("runner config: log.format must be text|json, got %q", c.Log.Format)
300 }
301 return nil
302 }
303
// normalizeHostPatterns lower-cases and trims each allowlist entry, then drops
// blanks and duplicates (the first occurrence wins, so order is kept). It
// validates each entry's shape:
//   - only a-z, 0-9, '-', '.' and '*' are allowed;
//   - no leading or trailing '.', and no ".." or "**";
//   - a wildcard may appear only once, as the leading "*." label (for example
//     "*.githubusercontent.com").
//
// It fails when no pattern remains.
func normalizeHostPatterns(patterns []string) ([]string, error) {
	seen := make(map[string]struct{}, len(patterns))
	out := make([]string, 0, len(patterns))
	for _, p := range patterns {
		p = strings.ToLower(strings.TrimSpace(p))
		if p == "" {
			continue
		}
		// Trim strips allowed characters from both ends; anything left over
		// means the pattern contains a disallowed character.
		if strings.ContainsAny(p, "/:") || strings.Trim(p, "*.abcdefghijklmnopqrstuvwxyz0123456789-") != "" {
			return nil, fmt.Errorf("invalid host pattern %q", p)
		}
		if strings.Contains(p, "**") || strings.Contains(p, "..") || strings.HasPrefix(p, ".") || strings.HasSuffix(p, ".") {
			return nil, fmt.Errorf("invalid host pattern %q", p)
		}
		// The prefix test alone would accept extra wildcards later in the
		// pattern ("*.a*b.com", "*.*.com"), so also require exactly one '*'.
		if strings.Contains(p, "*") && (!strings.HasPrefix(p, "*.") || strings.Count(p, "*") > 1) {
			return nil, fmt.Errorf("invalid wildcard host pattern %q", p)
		}
		if _, dup := seen[p]; dup {
			continue
		}
		seen[p] = struct{}{}
		out = append(out, p)
	}
	if len(out) == 0 {
		return nil, errors.New("must contain at least one host pattern")
	}
	return out, nil
}
332
// normalizeDNSServers trims each entry, drops blanks, and rejects anything
// that is not an IP address literal. It removes duplicates while keeping the
// first spelling of each address. Duplicates are detected on the parsed
// address, so "2001:DB8::1" and "2001:db8::1" count as the same server (an
// IPv6 zone is part of the address, so different zones stay distinct). An
// empty result is allowed.
func normalizeDNSServers(servers []string) ([]string, error) {
	seen := make(map[netip.Addr]struct{}, len(servers))
	out := make([]string, 0, len(servers))
	for _, s := range servers {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		// Reject embedded whitespace up front; it could otherwise survive
		// inside an IPv6 zone suffix.
		if strings.ContainsAny(s, " \t\r\n") {
			return nil, fmt.Errorf("invalid DNS server %q", s)
		}
		addr, err := netip.ParseAddr(s)
		if err != nil {
			return nil, fmt.Errorf("invalid DNS server %q", s)
		}
		if _, dup := seen[addr]; dup {
			continue
		}
		seen[addr] = struct{}{}
		out = append(out, s)
	}
	return out, nil
}
355
356 func normalizeLabels(labels []string) ([]string, error) {
357 if len(labels) == 1 && strings.Contains(labels[0], ",") {
358 return runnerlabels.ParseCSV(labels[0])
359 }
360 return runnerlabels.Normalize(labels)
361 }
362
// envLookup returns the value of key in environ, a list of KEY=VALUE entries.
// The first matching entry wins. A missing key and an empty value both yield "".
func envLookup(environ []string, key string) string {
	needle := key + "="
	for i := range environ {
		// needle is never empty, so a shorter result means the prefix matched.
		if rest := strings.TrimPrefix(environ[i], needle); len(rest) != len(environ[i]) {
			return rest
		}
	}
	return ""
}
372
373 func walkAndApply(v reflect.Value, t reflect.Type, prefix string, src map[string]string) error {
374 for i := 0; i < t.NumField(); i++ {
375 field := t.Field(i)
376 tag := field.Tag.Get("toml")
377 if tag == "" || tag == "-" {
378 continue
379 }
380 fieldPath := strings.ToUpper(tag)
381 if prefix != "" {
382 fieldPath = prefix + "__" + fieldPath
383 }
384 fv := v.Field(i)
385 if field.Type.Kind() == reflect.Struct && field.Type != reflect.TypeOf(time.Duration(0)) {
386 if err := walkAndApply(fv, field.Type, fieldPath, src); err != nil {
387 return err
388 }
389 continue
390 }
391 raw, ok := src[fieldPath]
392 if !ok {
393 continue
394 }
395 if err := setField(fv, field.Type, raw); err != nil {
396 return fmt.Errorf("runner config: %s: %w", strings.ReplaceAll(strings.ToLower(fieldPath), "__", "."), err)
397 }
398 }
399 return nil
400 }
401
// setField decodes raw into v, whose type is t. Supported kinds:
//   - string: stored verbatim (Validate trims where it matters);
//   - signed integers: base 10, surrounding whitespace ignored, range-checked
//     against t's bit size so that, e.g., an int8 field rejects "300";
//   - time.Duration: time.ParseDuration syntax such as "90s" or "1h30m";
//   - string slices: comma-separated, items trimmed, blank items dropped.
//
// Slice items are only split here. Field-specific normalization (labels, host
// patterns, DNS servers) is left to Validate, so values such as IPv6 DNS
// servers or "*." host patterns are not pushed through the label CSV parser.
// Any other kind yields an "unsupported" error.
func setField(v reflect.Value, t reflect.Type, raw string) error {
	switch t.Kind() {
	case reflect.String:
		v.SetString(raw)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		raw = strings.TrimSpace(raw)
		if t == reflect.TypeOf(time.Duration(0)) {
			d, err := time.ParseDuration(raw)
			if err != nil {
				return fmt.Errorf("invalid duration: %w", err)
			}
			v.SetInt(int64(d))
			return nil
		}
		// Parse at the field's own width: SetInt silently truncates values
		// that do not fit, so a 64-bit parse alone is not a range check.
		n, err := strconv.ParseInt(raw, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid int: %w", err)
		}
		v.SetInt(n)
	case reflect.Slice:
		if t.Elem().Kind() != reflect.String {
			return fmt.Errorf("unsupported slice type %s", t)
		}
		var items []string
		for _, item := range strings.Split(raw, ",") {
			if item = strings.TrimSpace(item); item != "" {
				items = append(items, item)
			}
		}
		// MakeSlice/SetString also works for named slice or element types.
		list := reflect.MakeSlice(t, len(items), len(items))
		for i, item := range items {
			list.Index(i).SetString(item)
		}
		v.Set(list)
	default:
		return fmt.Errorf("unsupported field kind %s", t.Kind())
	}
	return nil
}
434