tenseleyflow/shithub / fdc8d5b

Browse files

Add layered config loader (defaults / TOML file / env / flags) with secret redaction

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
fdc8d5bf6da3d2d47c6c897276aa73eb0ca2122b
Parents
58257f9
Tree
05eb5af

3 changed files

StatusFile+-
A internal/infra/config/config.go 313 0
A internal/infra/config/config_test.go 110 0
A internal/infra/config/redact.go 72 0
internal/infra/config/config.goadded
@@ -0,0 +1,313 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package config owns the layered configuration loader.
4
+//
5
+// Precedence (lowest → highest):
6
+//
7
+//  1. Built-in defaults (this file)
8
+//  2. TOML file (path from $SHITHUB_CONFIG, falling back to /etc/shithub/config.toml)
9
+//  3. Environment variables (SHITHUB_<area>__<key>; nested levels joined with "__")
10
+//  4. CLI flag overrides handed in by the caller
11
+//
12
+// The Config struct is the single source of truth for what is configurable.
13
+// Every consumer reads from this struct rather than from os.Getenv directly.
14
+package config
15
+
16
+import (
17
+	"errors"
18
+	"fmt"
19
+	"os"
20
+	"reflect"
21
+	"strconv"
22
+	"strings"
23
+	"time"
24
+
25
+	"github.com/BurntSushi/toml"
26
+)
27
+
28
// Config is the typed root of the configuration tree. Each field maps to a
// top-level TOML key or table named by its `toml` tag; the same tags drive
// the environment-variable and CLI-flag override paths (see walkAndApply).
type Config struct {
	Env            string               `toml:"env"` // "dev", "staging", "prod" (checked and lowercased by Validate)
	Web            WebConfig            `toml:"web"`
	DB             DBConfig             `toml:"db"`
	Log            LogConfig            `toml:"log"`
	Metrics        MetricsConfig        `toml:"metrics"`
	Tracing        TracingConfig        `toml:"tracing"`
	ErrorReporting ErrorReportingConfig `toml:"error_reporting"`
	Session        SessionConfig        `toml:"session"`
}
39
+
40
// WebConfig holds HTTP server settings. Duration values supplied through
// environment variables or CLI overrides are parsed with time.ParseDuration
// (for example "30s").
type WebConfig struct {
	Addr            string        `toml:"addr"` // listen address; Validate rejects an empty value
	ReadTimeout     time.Duration `toml:"read_timeout"`
	WriteTimeout    time.Duration `toml:"write_timeout"`
	ShutdownTimeout time.Duration `toml:"shutdown_timeout"`
}
47
+
48
// DBConfig holds Postgres settings.
type DBConfig struct {
	// URL has no default. When it is still empty after all layers are merged,
	// applyAliases fills it from the legacy SHITHUB_DATABASE_URL variable.
	URL            string        `toml:"url"`
	MaxConns       int           `toml:"max_conns"`
	MinConns       int           `toml:"min_conns"`
	ConnectTimeout time.Duration `toml:"connect_timeout"`
}
55
+
56
// LogConfig holds slog settings. Both values are matched case-insensitively
// by Validate and stored lowercased.
type LogConfig struct {
	Level  string `toml:"level"`  // debug | info | warn | error
	Format string `toml:"format"` // text (dev) | json (prod)
}
61
+
62
// MetricsConfig configures the /metrics endpoint.
type MetricsConfig struct {
	Enabled       bool   `toml:"enabled"`
	BasicAuthUser string `toml:"basic_auth_user"`
	BasicAuthPass string `toml:"basic_auth_pass"` // masked by PrintRedacted (name contains "pass")
}
68
+
69
// TracingConfig configures the OpenTelemetry exporter.
type TracingConfig struct {
	Enabled     bool    `toml:"enabled"`
	Endpoint    string  `toml:"endpoint"`    // OTLP HTTP endpoint; required by Validate when Enabled is true
	SampleRate  float64 `toml:"sample_rate"` // Validate requires a value in [0, 1]
	ServiceName string  `toml:"service_name"`
}
76
+
77
// ErrorReportingConfig configures the Sentry/GlitchTip-protocol DSN.
type ErrorReportingConfig struct {
	DSN         string `toml:"dsn"` // masked by PrintRedacted (name contains "dsn")
	Environment string `toml:"environment"`
	Release     string `toml:"release"`
}
83
+
84
// SessionConfig configures the cookie session store.
type SessionConfig struct {
	// KeyB64 has no default. When it is still empty after all layers are
	// merged, applyAliases fills it from SHITHUB_SESSION_KEY.
	// NOTE(review): the name suggests a base64-encoded key — confirm with the consumer.
	KeyB64 string        `toml:"key_b64"`
	MaxAge time.Duration `toml:"max_age"`
	Secure bool          `toml:"secure"`
}
90
+
91
+// Defaults returns the zero-config baseline.
92
+func Defaults() Config {
93
+	return Config{
94
+		Env: "dev",
95
+		Web: WebConfig{
96
+			Addr:            ":8080",
97
+			ReadTimeout:     30 * time.Second,
98
+			WriteTimeout:    30 * time.Second,
99
+			ShutdownTimeout: 10 * time.Second,
100
+		},
101
+		DB: DBConfig{
102
+			MaxConns:       10,
103
+			MinConns:       0,
104
+			ConnectTimeout: 5 * time.Second,
105
+		},
106
+		Log: LogConfig{
107
+			Level:  "info",
108
+			Format: "text",
109
+		},
110
+		Metrics: MetricsConfig{
111
+			Enabled: true,
112
+		},
113
+		Tracing: TracingConfig{
114
+			Enabled:     false,
115
+			SampleRate:  0.05,
116
+			ServiceName: "shithubd",
117
+		},
118
+		Session: SessionConfig{
119
+			MaxAge: 30 * 24 * time.Hour,
120
+			Secure: false,
121
+		},
122
+	}
123
+}
124
+
125
+// Load resolves configuration in the documented precedence order. CLI
126
+// overrides may be nil. The TOML file is optional — its absence is not an
127
+// error; its existence with bad syntax IS.
128
+func Load(cliOverrides map[string]string) (Config, error) {
129
+	cfg := Defaults()
130
+	if err := mergeFile(&cfg); err != nil {
131
+		return cfg, err
132
+	}
133
+	if err := mergeEnv(&cfg, os.Environ()); err != nil {
134
+		return cfg, err
135
+	}
136
+	if err := mergeFlags(&cfg, cliOverrides); err != nil {
137
+		return cfg, err
138
+	}
139
+	applyAliases(&cfg)
140
+	if err := Validate(&cfg); err != nil {
141
+		return cfg, err
142
+	}
143
+	return cfg, nil
144
+}
145
+
146
+// applyAliases honors well-known env-var aliases that don't follow the
147
+// nested-key convention. SHITHUB_DATABASE_URL is the S01-era name for
148
+// db.url and remains supported.
149
+func applyAliases(cfg *Config) {
150
+	if cfg.DB.URL == "" {
151
+		if v := os.Getenv("SHITHUB_DATABASE_URL"); v != "" {
152
+			cfg.DB.URL = v
153
+		}
154
+	}
155
+	if cfg.Session.KeyB64 == "" {
156
+		if v := os.Getenv("SHITHUB_SESSION_KEY"); v != "" {
157
+			cfg.Session.KeyB64 = v
158
+		}
159
+	}
160
+}
161
+
162
+// Validate enforces invariants. Errors are precise enough to point at the
163
+// offending key.
164
+func Validate(c *Config) error {
165
+	switch strings.ToLower(c.Env) {
166
+	case "dev", "staging", "prod":
167
+		c.Env = strings.ToLower(c.Env)
168
+	default:
169
+		return fmt.Errorf("config: env: must be dev|staging|prod, got %q", c.Env)
170
+	}
171
+	switch strings.ToLower(c.Log.Level) {
172
+	case "debug", "info", "warn", "error":
173
+		c.Log.Level = strings.ToLower(c.Log.Level)
174
+	default:
175
+		return fmt.Errorf("config: log.level: must be debug|info|warn|error, got %q", c.Log.Level)
176
+	}
177
+	switch strings.ToLower(c.Log.Format) {
178
+	case "text", "json":
179
+		c.Log.Format = strings.ToLower(c.Log.Format)
180
+	default:
181
+		return fmt.Errorf("config: log.format: must be text|json, got %q", c.Log.Format)
182
+	}
183
+	if c.Web.Addr == "" {
184
+		return errors.New("config: web.addr is required")
185
+	}
186
+	if c.Tracing.Enabled && c.Tracing.Endpoint == "" {
187
+		return errors.New("config: tracing.endpoint is required when tracing.enabled=true")
188
+	}
189
+	if c.Tracing.SampleRate < 0 || c.Tracing.SampleRate > 1 {
190
+		return fmt.Errorf("config: tracing.sample_rate: must be in [0, 1], got %v", c.Tracing.SampleRate)
191
+	}
192
+	return nil
193
+}
194
+
195
+// mergeFile reads the TOML file (when present) over cfg.
196
+func mergeFile(cfg *Config) error {
197
+	path := os.Getenv("SHITHUB_CONFIG")
198
+	if path == "" {
199
+		path = "/etc/shithub/config.toml"
200
+	}
201
+	body, err := os.ReadFile(path) //nolint:gosec // operator-supplied path
202
+	if err != nil {
203
+		if errors.Is(err, os.ErrNotExist) {
204
+			return nil
205
+		}
206
+		return fmt.Errorf("config: read %s: %w", path, err)
207
+	}
208
+	if _, err := toml.Decode(string(body), cfg); err != nil {
209
+		return fmt.Errorf("config: parse %s: %w", path, err)
210
+	}
211
+	return nil
212
+}
213
+
214
+// mergeEnv overrides cfg from environment variables. Naming convention:
215
+// SHITHUB_<area>__<key> (double-underscore separates nested levels).
216
+// Single-segment keys also accept SHITHUB_<key>.
217
+func mergeEnv(cfg *Config, environ []string) error {
218
+	envMap := make(map[string]string, len(environ))
219
+	for _, kv := range environ {
220
+		if !strings.HasPrefix(kv, "SHITHUB_") {
221
+			continue
222
+		}
223
+		eq := strings.IndexByte(kv, '=')
224
+		if eq < 0 {
225
+			continue
226
+		}
227
+		key := strings.TrimPrefix(kv[:eq], "SHITHUB_")
228
+		envMap[key] = kv[eq+1:]
229
+	}
230
+	return walkAndApply(reflect.ValueOf(cfg).Elem(), reflect.TypeOf(*cfg), "", envMap)
231
+}
232
+
233
+// mergeFlags applies CLI overrides. Keys use TOML notation
234
+// ("web.addr", "tracing.endpoint", etc.).
235
+func mergeFlags(cfg *Config, overrides map[string]string) error {
236
+	if len(overrides) == 0 {
237
+		return nil
238
+	}
239
+	envStyle := make(map[string]string, len(overrides))
240
+	for k, v := range overrides {
241
+		envStyle[strings.ToUpper(strings.ReplaceAll(k, ".", "__"))] = v
242
+	}
243
+	return walkAndApply(reflect.ValueOf(cfg).Elem(), reflect.TypeOf(*cfg), "", envStyle)
244
+}
245
+
246
+// walkAndApply walks struct fields recursively, applying values from src
247
+// keyed by uppercased dot-then-double-underscore-joined paths.
248
+func walkAndApply(v reflect.Value, t reflect.Type, prefix string, src map[string]string) error {
249
+	for i := 0; i < t.NumField(); i++ {
250
+		field := t.Field(i)
251
+		tag := field.Tag.Get("toml")
252
+		if tag == "" || tag == "-" {
253
+			continue
254
+		}
255
+		fieldPath := strings.ToUpper(tag)
256
+		if prefix != "" {
257
+			fieldPath = prefix + "__" + fieldPath
258
+		}
259
+		fv := v.Field(i)
260
+
261
+		if field.Type.Kind() == reflect.Struct && field.Type != reflect.TypeOf(time.Duration(0)) {
262
+			if err := walkAndApply(fv, field.Type, fieldPath, src); err != nil {
263
+				return err
264
+			}
265
+			continue
266
+		}
267
+
268
+		raw, ok := src[fieldPath]
269
+		if !ok {
270
+			continue
271
+		}
272
+		if err := setField(fv, field.Type, raw); err != nil {
273
+			return fmt.Errorf("config: %s: %w", strings.ReplaceAll(strings.ToLower(fieldPath), "__", "."), err)
274
+		}
275
+	}
276
+	return nil
277
+}
278
+
279
// setField parses raw according to the Go type t and stores the result in v.
//
// Supported kinds are string, bool, signed and unsigned integers, floats and
// time.Duration (parsed with time.ParseDuration, e.g. "8s"). Numbers are
// parsed at the field's own bit size, so a value that does not fit (for
// example "300" into an int8) is reported as an error instead of being
// silently truncated when stored. Any other kind yields an "unsupported
// field kind" error. v must be settable and of type t.
func setField(v reflect.Value, t reflect.Type, raw string) error {
	switch t.Kind() {
	case reflect.String:
		v.SetString(raw)
	case reflect.Bool:
		b, err := strconv.ParseBool(raw)
		if err != nil {
			return fmt.Errorf("invalid bool: %w", err)
		}
		v.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// time.Duration has Kind Int64, so it is distinguished by exact type
		// and accepts human-readable values instead of raw nanoseconds.
		if t == reflect.TypeOf(time.Duration(0)) {
			d, err := time.ParseDuration(raw)
			if err != nil {
				return fmt.Errorf("invalid duration: %w", err)
			}
			v.SetInt(int64(d))
			return nil
		}
		n, err := strconv.ParseInt(raw, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid int: %w", err)
		}
		v.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n, err := strconv.ParseUint(raw, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid uint: %w", err)
		}
		v.SetUint(n)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(raw, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid float: %w", err)
		}
		v.SetFloat(f)
	default:
		return fmt.Errorf("unsupported field kind %s", t.Kind())
	}
	return nil
}
internal/infra/config/config_test.goadded
@@ -0,0 +1,110 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package config
4
+
5
+import (
6
+	"strings"
7
+	"testing"
8
+	"time"
9
+)
10
+
11
+func TestDefaults_Validate(t *testing.T) {
12
+	t.Parallel()
13
+	cfg := Defaults()
14
+	if err := Validate(&cfg); err != nil {
15
+		t.Fatalf("Validate(Defaults()): %v", err)
16
+	}
17
+}
18
+
19
+func TestValidate_RejectsBadEnv(t *testing.T) {
20
+	t.Parallel()
21
+	cfg := Defaults()
22
+	cfg.Env = "production" // typo of "prod"
23
+	if err := Validate(&cfg); err == nil {
24
+		t.Errorf("expected validation error for env=production")
25
+	}
26
+}
27
+
28
+func TestValidate_RejectsTracingWithoutEndpoint(t *testing.T) {
29
+	t.Parallel()
30
+	cfg := Defaults()
31
+	cfg.Tracing.Enabled = true
32
+	if err := Validate(&cfg); err == nil {
33
+		t.Errorf("expected validation error when tracing.enabled=true and endpoint empty")
34
+	}
35
+}
36
+
37
+func TestValidate_RejectsBadSampleRate(t *testing.T) {
38
+	t.Parallel()
39
+	cfg := Defaults()
40
+	cfg.Tracing.Enabled = true
41
+	cfg.Tracing.Endpoint = "http://otel:4318"
42
+	cfg.Tracing.SampleRate = 2.0
43
+	if err := Validate(&cfg); err == nil {
44
+		t.Errorf("expected validation error for sample_rate=2.0")
45
+	}
46
+}
47
+
48
+func TestMergeEnv_AppliesNestedKeys(t *testing.T) {
49
+	t.Parallel()
50
+	cfg := Defaults()
51
+	env := []string{
52
+		"SHITHUB_WEB__ADDR=:9090",
53
+		"SHITHUB_DB__MAX_CONNS=42",
54
+		"SHITHUB_TRACING__ENABLED=true",
55
+		"SHITHUB_TRACING__ENDPOINT=http://otel:4318",
56
+		"SHITHUB_DB__CONNECT_TIMEOUT=8s",
57
+	}
58
+	if err := mergeEnv(&cfg, env); err != nil {
59
+		t.Fatalf("mergeEnv: %v", err)
60
+	}
61
+	if cfg.Web.Addr != ":9090" {
62
+		t.Errorf("Web.Addr: got %q", cfg.Web.Addr)
63
+	}
64
+	if cfg.DB.MaxConns != 42 {
65
+		t.Errorf("DB.MaxConns: got %d", cfg.DB.MaxConns)
66
+	}
67
+	if !cfg.Tracing.Enabled {
68
+		t.Errorf("Tracing.Enabled: not set")
69
+	}
70
+	if cfg.DB.ConnectTimeout != 8*time.Second {
71
+		t.Errorf("DB.ConnectTimeout: got %v", cfg.DB.ConnectTimeout)
72
+	}
73
+}
74
+
75
+func TestPrintRedacted_HidesSecrets(t *testing.T) {
76
+	t.Parallel()
77
+	cfg := Defaults()
78
+	cfg.DB.URL = "postgres://shithub:hunter2@localhost/shithub"
79
+	cfg.Session.KeyB64 = "supersecretkey"
80
+	cfg.Metrics.BasicAuthPass = "metrics-pass"
81
+	cfg.ErrorReporting.DSN = "https://abc@sentry.example/1"
82
+
83
+	out, err := PrintRedacted(cfg)
84
+	if err != nil {
85
+		t.Fatalf("PrintRedacted: %v", err)
86
+	}
87
+
88
+	for _, leak := range []string{"hunter2", "supersecretkey", "metrics-pass", "https://abc@sentry"} {
89
+		if strings.Contains(out, leak) {
90
+			t.Errorf("PrintRedacted leaked %q\noutput: %s", leak, out)
91
+		}
92
+	}
93
+	if !strings.Contains(out, "***") {
94
+		t.Errorf("PrintRedacted produced no *** redactions; output:\n%s", out)
95
+	}
96
+}
97
+
98
+func TestMergeFlags_OverridesEnv(t *testing.T) {
99
+	t.Parallel()
100
+	cfg := Defaults()
101
+	if err := mergeEnv(&cfg, []string{"SHITHUB_WEB__ADDR=:9090"}); err != nil {
102
+		t.Fatalf("mergeEnv: %v", err)
103
+	}
104
+	if err := mergeFlags(&cfg, map[string]string{"web.addr": ":7777"}); err != nil {
105
+		t.Fatalf("mergeFlags: %v", err)
106
+	}
107
+	if cfg.Web.Addr != ":7777" {
108
+		t.Errorf("Web.Addr: got %q, want :7777", cfg.Web.Addr)
109
+	}
110
+}
internal/infra/config/redact.goadded
@@ -0,0 +1,72 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package config
4
+
5
+import (
6
+	"fmt"
7
+	"reflect"
8
+	"strings"
9
+
10
+	"github.com/BurntSushi/toml"
11
+)
12
+
13
// secretFieldNames lists lowercase substrings that mark a field as a secret.
// shouldRedact lowercases the Go field name (not the toml tag) and tests it
// against each entry with strings.Contains. Matches are deliberately broad —
// better to redact a non-secret than leak one. URL fields are included
// because connection URLs commonly carry credentials in the userinfo
// component.
var secretFieldNames = []string{
	"password", "pass", // "pass" already covers "password"; both kept for readability
	"secret",
	"key",
	"token",
	"dsn",
	"url",
}
25
+
26
+// PrintRedacted writes the config to w in TOML form with secrets replaced
27
+// by `***`. The transformation operates on a deep copy so the live Config
28
+// is never mutated.
29
+func PrintRedacted(c Config) (string, error) {
30
+	cp := redactCopy(reflect.ValueOf(c)).Interface().(Config)
31
+	var buf strings.Builder
32
+	if err := toml.NewEncoder(&buf).Encode(cp); err != nil {
33
+		return "", fmt.Errorf("config: encode: %w", err)
34
+	}
35
+	return buf.String(), nil
36
+}
37
+
38
+func redactCopy(v reflect.Value) reflect.Value {
39
+	out := reflect.New(v.Type()).Elem()
40
+	switch v.Kind() {
41
+	case reflect.Struct:
42
+		for i := 0; i < v.NumField(); i++ {
43
+			fv := v.Field(i)
44
+			ft := v.Type().Field(i)
45
+			if !ft.IsExported() {
46
+				continue
47
+			}
48
+			if fv.Kind() == reflect.Struct {
49
+				out.Field(i).Set(redactCopy(fv))
50
+				continue
51
+			}
52
+			if shouldRedact(ft.Name) && fv.Kind() == reflect.String && fv.String() != "" {
53
+				out.Field(i).SetString("***")
54
+				continue
55
+			}
56
+			out.Field(i).Set(fv)
57
+		}
58
+	default:
59
+		out.Set(v)
60
+	}
61
+	return out
62
+}
63
+
64
+func shouldRedact(fieldName string) bool {
65
+	lower := strings.ToLower(fieldName)
66
+	for _, needle := range secretFieldNames {
67
+		if strings.Contains(lower, needle) {
68
+			return true
69
+		}
70
+	}
71
+	return false
72
+}