tenseleyflow/shithub / 4a9b2e1

Browse files

S33: SSRF defense (block-list + dial-the-IP transport)

Authored by espadonne
SHA
4a9b2e14e57c25f17634689b5ca4928c563a049b
Parents
0ebc1ff
Tree
a106067

2 changed files

StatusFile+-
A internal/webhook/ssrf.go 293 0
A internal/webhook/ssrf_test.go 98 0
internal/webhook/ssrf.goadded
@@ -0,0 +1,293 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package webhook owns outbound webhook delivery: signing, SSRF
4
+// defense, retry/backoff, and the deliverer + fanout workers.
5
+//
6
+// SSRF philosophy: webhooks point at attacker-controlled URLs by
7
+// design. The defense pattern is documented in `docs/internal/webhooks.md`
8
+// and enforced in this file:
9
+//
10
+//  1. Resolve the hostname to a set of IPs.
11
+//  2. Reject the request if ANY resolved IP is in a private/loopback/
12
+//     link-local/etc. range — even if other IPs would have been fine.
13
+//     A mixed-result hostname is suspicious enough to refuse.
14
+//  3. Pick a public IP and dial it directly, passing the original
15
+//     hostname for SNI / Host header. This defeats DNS-rebinding
16
+//     because the IP we validated is the IP we connect to (no second
17
+//     resolve at dial time).
18
+//  4. Reject schemes other than http/https and ports outside the
19
+//     well-known web ports unless the operator allow-listed them.
20
+package webhook
21
+
22
+import (
23
+	"context"
24
+	"errors"
25
+	"fmt"
26
+	"net"
27
+	"net/http"
28
+	"net/url"
29
+	"strconv"
30
+	"time"
31
+)
32
+
33
// SSRFError reports that a URL was refused, either during pre-flight
// validation or while dialing. The message is written for operators
// and is meant to stay in our own logs / UI rather than being relayed
// to the remote endpoint.
type SSRFError struct {
	URL    string // the URL (or host:port at dial time) that was rejected
	Reason string // human-readable explanation of the rejection
}

// Error implements the error interface, rendering
// "ssrf: <Reason>: <URL>".
func (e *SSRFError) Error() string {
	rendered := fmt.Sprintf("ssrf: %s: %s", e.Reason, e.URL)
	return rendered
}
43
+
44
// SSRFConfig tunes the defense. Defaults are safe for a single-tenant
// public deployment; self-hosters extend AllowedHosts / AllowedPorts
// when delivering to internal CI behind a private IP. Zero-value
// fields are replaced with the DefaultSSRFConfig values by
// applyDefaults before Validate or HTTPClient use them.
type SSRFConfig struct {
	// AllowedSchemes restricts URL schemes. Default ["http", "https"].
	// Matching is exact (case-sensitive) against the parsed scheme.
	AllowedSchemes []string
	// AllowedPorts is the set of TCP ports the deliverer is willing to
	// dial. Defaults to {80, 443, 8080, 8443}; operators add internal
	// ports here. Checked both in Validate and again at dial time.
	AllowedPorts []int
	// AllowPrivateNetworks, when true, skips the IP block-list for
	// EVERY hostname at dial time — it is a global switch, not a
	// per-host one. To reach a specific internal target such as
	// `ci.internal` while still rejecting any other hostname that
	// resolves to a private IP, leave this false and list the host in
	// AllowedHosts instead.
	AllowPrivateNetworks bool
	// AllowedHosts is a hostname allow-list. When a dialed hostname
	// matches an entry, the IP block-list is skipped for that hostname
	// only. Match is exact (no wildcards) and ASCII case-insensitive.
	AllowedHosts []string
	// DialTimeout caps the per-dial connect time. Default 10s.
	DialTimeout time.Duration
	// RequestTimeout caps the total request time (connect + read); it
	// is used for both the client timeout and the response-header
	// timeout. Default 30s per the spec.
	RequestTimeout time.Duration
	// Resolver is plumbed for tests. nil => net.DefaultResolver.
	Resolver *net.Resolver
}
72
+
73
+// DefaultSSRFConfig returns the production defaults. Callers add to
74
+// the slices as needed; the zero-value SSRFConfig is also valid (it
75
+// will pick the same defaults at validation time).
76
+func DefaultSSRFConfig() SSRFConfig {
77
+	return SSRFConfig{
78
+		AllowedSchemes: []string{"http", "https"},
79
+		AllowedPorts:   []int{80, 443, 8080, 8443},
80
+		DialTimeout:    10 * time.Second,
81
+		RequestTimeout: 30 * time.Second,
82
+	}
83
+}
84
+
85
+// HTTPClient returns an *http.Client configured with the SSRF-safe
86
+// dialer. The transport intentionally disables redirect-following:
87
+// 3xx is treated as success and a redirect target's IP would otherwise
88
+// bypass our pre-flight check.
89
+func (c SSRFConfig) HTTPClient() *http.Client {
90
+	cfg := c.applyDefaults()
91
+	tr := &http.Transport{
92
+		DialContext:           cfg.dialContext,
93
+		ResponseHeaderTimeout: cfg.RequestTimeout,
94
+		ForceAttemptHTTP2:     false,
95
+		// No keep-alive across deliveries — webhooks are sparse and
96
+		// connection reuse complicates the validate-then-dial chain.
97
+		DisableKeepAlives: true,
98
+	}
99
+	return &http.Client{
100
+		Transport: tr,
101
+		Timeout:   cfg.RequestTimeout,
102
+		CheckRedirect: func(*http.Request, []*http.Request) error {
103
+			return http.ErrUseLastResponse
104
+		},
105
+	}
106
+}
107
+
108
+// Validate checks the URL shape (scheme/port/host) without resolving
109
+// DNS. Returns *SSRFError on rejection. The deliverer also re-resolves
110
+// at dial time inside dialContext to defeat rebinding.
111
+func (c SSRFConfig) Validate(rawURL string) error {
112
+	cfg := c.applyDefaults()
113
+	u, err := url.Parse(rawURL)
114
+	if err != nil {
115
+		return &SSRFError{URL: rawURL, Reason: "malformed URL"}
116
+	}
117
+	if !stringSetContains(cfg.AllowedSchemes, u.Scheme) {
118
+		return &SSRFError{URL: rawURL, Reason: "scheme " + u.Scheme + " not allowed"}
119
+	}
120
+	host := u.Hostname()
121
+	if host == "" {
122
+		return &SSRFError{URL: rawURL, Reason: "missing host"}
123
+	}
124
+	port := u.Port()
125
+	if port == "" {
126
+		switch u.Scheme {
127
+		case "http":
128
+			port = "80"
129
+		case "https":
130
+			port = "443"
131
+		}
132
+	}
133
+	pn, perr := strconv.Atoi(port)
134
+	if perr != nil || pn <= 0 || pn > 65535 {
135
+		return &SSRFError{URL: rawURL, Reason: "invalid port"}
136
+	}
137
+	if !intSetContains(cfg.AllowedPorts, pn) {
138
+		return &SSRFError{URL: rawURL, Reason: "port " + port + " not in allow-list"}
139
+	}
140
+	return nil
141
+}
142
+
143
+// dialContext is the SSRF-safe dialer. It re-resolves the hostname at
144
+// dial time, validates every returned IP, and connects to the first
145
+// allowed IP using the original hostname for SNI.
146
+func (c SSRFConfig) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
147
+	host, port, err := net.SplitHostPort(addr)
148
+	if err != nil {
149
+		return nil, &SSRFError{URL: addr, Reason: "split host:port: " + err.Error()}
150
+	}
151
+	pn, _ := strconv.Atoi(port)
152
+	if !intSetContains(c.AllowedPorts, pn) {
153
+		return nil, &SSRFError{URL: addr, Reason: "port " + port + " not in allow-list"}
154
+	}
155
+
156
+	hostAllowed := stringSetContainsFold(c.AllowedHosts, host)
157
+	resolver := c.Resolver
158
+	if resolver == nil {
159
+		resolver = net.DefaultResolver
160
+	}
161
+	ips, err := resolver.LookupIPAddr(ctx, host)
162
+	if err != nil {
163
+		return nil, &SSRFError{URL: addr, Reason: "DNS resolve: " + err.Error()}
164
+	}
165
+	if len(ips) == 0 {
166
+		return nil, &SSRFError{URL: addr, Reason: "no IPs resolved"}
167
+	}
168
+	// Reject if ANY IP is forbidden — a mixed-result hostname is
169
+	// suspicious enough to refuse. The exception is when the host is
170
+	// allow-listed (self-hoster scenario).
171
+	for _, ipa := range ips {
172
+		if !hostAllowed && !c.AllowPrivateNetworks && isForbiddenIP(ipa.IP) {
173
+			return nil, &SSRFError{URL: addr, Reason: "resolved to forbidden IP " + ipa.IP.String()}
174
+		}
175
+	}
176
+	// Dial the first IP. We pass the literal IP so the dialer doesn't
177
+	// re-resolve under us; the URL's Host header (set by net/http) keeps
178
+	// the original hostname for routing/SNI.
179
+	dialer := &net.Dialer{Timeout: c.DialTimeout}
180
+	dialAddr := net.JoinHostPort(ips[0].IP.String(), port)
181
+	return dialer.DialContext(ctx, network, dialAddr)
182
+}
183
+
184
+// applyDefaults fills in zero-value fields with defaults. Returns a
185
+// copy so the caller's struct stays unchanged.
186
+func (c SSRFConfig) applyDefaults() SSRFConfig {
187
+	def := DefaultSSRFConfig()
188
+	if len(c.AllowedSchemes) == 0 {
189
+		c.AllowedSchemes = def.AllowedSchemes
190
+	}
191
+	if len(c.AllowedPorts) == 0 {
192
+		c.AllowedPorts = def.AllowedPorts
193
+	}
194
+	if c.DialTimeout == 0 {
195
+		c.DialTimeout = def.DialTimeout
196
+	}
197
+	if c.RequestTimeout == 0 {
198
+		c.RequestTimeout = def.RequestTimeout
199
+	}
200
+	return c
201
+}
202
+
203
// isForbiddenIP reports whether ip lies in a range that outbound
// webhook deliveries must never reach. It fails closed: a nil or
// malformed (neither 4- nor 16-byte) address is forbidden.
//
// Forbidden ranges:
//   - unspecified, loopback, link-local unicast and all multicast
//     (IPv4 and IPv6);
//   - IPv4 0.0.0.0/8, RFC 1918 (10/8, 172.16/12, 192.168/16), CGNAT
//     100.64/10, autoconf 169.254/16, benchmarking 198.18/15, and the
//     reserved 240.0.0.0/4 block (which includes broadcast);
//   - IPv6 unique-local fc00::/7;
//   - NAT64 addresses (64:ff9b::/96) whose embedded IPv4 address is
//     itself forbidden.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are classified by their
// IPv4 value.
func isForbiddenIP(ip net.IP) bool {
	if ip == nil {
		return true
	}
	// IsMulticast already covers the link-local and interface-local
	// multicast scopes, so they need no separate checks.
	if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsMulticast() {
		return true
	}

	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0: // 0.0.0.0/8 "this network"
		case ip4[0] == 10: // 10.0.0.0/8
		case ip4[0] == 100 && ip4[1]&0xc0 == 64: // 100.64.0.0/10 CGNAT
		case ip4[0] == 169 && ip4[1] == 254: // 169.254.0.0/16 (belt-and-braces)
		case ip4[0] == 172 && ip4[1]&0xf0 == 16: // 172.16.0.0/12
		case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16
		case ip4[0] == 198 && ip4[1]&0xfe == 18: // 198.18.0.0/15 benchmarking
		case ip4[0] >= 240: // 240.0.0.0/4 reserved + broadcast
		default:
			return false
		}
		return true
	}

	ip16 := ip.To16()
	if ip16 == nil {
		// Not a valid IPv4 or IPv6 address: refuse rather than guess.
		return true
	}
	// IPv6 unique-local addresses (fc00::/7) — covers fd00::/8 too.
	if ip16[0]&0xfe == 0xfc {
		return true
	}
	// A NAT64 gateway forwards these to the IPv4 address held in the
	// last four bytes, so classify that address instead.
	if hasNAT64Prefix(ip16) {
		return isForbiddenIP(net.IP(ip16[12:16]))
	}
	return false
}

// hasNAT64Prefix reports whether the 16-byte address ip16 lies in the
// NAT64 well-known prefix 64:ff9b::/96.
func hasNAT64Prefix(ip16 net.IP) bool {
	if ip16[0] != 0x00 || ip16[1] != 0x64 || ip16[2] != 0xff || ip16[3] != 0x9b {
		return false
	}
	for _, b := range ip16[4:12] {
		if b != 0 {
			return false
		}
	}
	return true
}
242
+
243
// stringSetContains reports whether v is present in set, using exact
// (case-sensitive) string comparison. An empty or nil set contains
// nothing.
func stringSetContains(set []string, v string) bool {
	for i := 0; i < len(set); i++ {
		if set[i] != v {
			continue
		}
		return true
	}
	return false
}
251
+
252
+func stringSetContainsFold(set []string, v string) bool {
253
+	for _, s := range set {
254
+		if equalFold(s, v) {
255
+			return true
256
+		}
257
+	}
258
+	return false
259
+}
260
+
261
// equalFold reports whether a and b are equal when the ASCII letters
// A-Z are treated as their lowercase forms. The comparison is
// byte-wise: no Unicode case folding is applied, so non-ASCII bytes
// must match exactly.
func equalFold(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	// lower maps an ASCII uppercase byte to lowercase and leaves every
	// other byte as it is.
	lower := func(ch byte) byte {
		if ch >= 'A' && ch <= 'Z' {
			return ch + ('a' - 'A')
		}
		return ch
	}
	for i := 0; i < len(a); i++ {
		if lower(a[i]) != lower(b[i]) {
			return false
		}
	}
	return true
}
279
+
280
// intSetContains reports whether v is present in set. An empty or nil
// set contains nothing.
func intSetContains(set []int, v int) bool {
	for i := 0; i < len(set); i++ {
		if set[i] != v {
			continue
		}
		return true
	}
	return false
}
288
+
289
+// IsSSRF reports whether err is or wraps an SSRFError.
290
+func IsSSRF(err error) bool {
291
+	var s *SSRFError
292
+	return errors.As(err, &s)
293
+}
internal/webhook/ssrf_test.goadded
@@ -0,0 +1,98 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package webhook
4
+
5
+import (
6
+	"net"
7
+	"strings"
8
+	"testing"
9
+)
10
+
11
+func TestValidateRejectsBadShapes(t *testing.T) {
12
+	c := DefaultSSRFConfig()
13
+	cases := []struct {
14
+		name, url, wantSubstr string
15
+	}{
16
+		{"empty", "", "scheme  not allowed"},
17
+		{"file scheme", "file:///etc/passwd", "scheme file not allowed"},
18
+		{"ftp scheme", "ftp://example.com/", "scheme ftp not allowed"},
19
+		{"missing host", "http:///path", "missing host"},
20
+		{"non-allowed port", "http://example.com:9999/x", "port 9999"},
21
+	}
22
+	for _, tc := range cases {
23
+		t.Run(tc.name, func(t *testing.T) {
24
+			err := c.Validate(tc.url)
25
+			if err == nil {
26
+				t.Fatalf("Validate(%q) = nil; want SSRFError", tc.url)
27
+			}
28
+			if !strings.Contains(err.Error(), tc.wantSubstr) {
29
+				t.Fatalf("Validate(%q) = %q; want substring %q", tc.url, err, tc.wantSubstr)
30
+			}
31
+		})
32
+	}
33
+}
34
+
35
+func TestValidatePassesGoodShapes(t *testing.T) {
36
+	c := DefaultSSRFConfig()
37
+	cases := []string{
38
+		"http://example.com/x",
39
+		"https://example.com:443/x",
40
+		"http://example.com:8080/y",
41
+		"https://example.com:8443/y",
42
+	}
43
+	for _, u := range cases {
44
+		if err := c.Validate(u); err != nil {
45
+			t.Fatalf("Validate(%q) = %v; want nil", u, err)
46
+		}
47
+	}
48
+}
49
+
50
+func TestIsForbiddenIPClassifiesCorrectly(t *testing.T) {
51
+	forbidden := []string{
52
+		"127.0.0.1", "127.255.255.254",
53
+		"10.0.0.1", "10.255.255.255",
54
+		"172.16.0.1", "172.31.255.255",
55
+		"192.168.0.1",
56
+		"100.64.0.1",       // CGNAT
57
+		"169.254.169.254", // AWS metadata service
58
+		"0.0.0.0",
59
+		"255.255.255.255",
60
+		"::1",
61
+		"fe80::1",        // link-local
62
+		"fd00::1",        // ULA
63
+		"fc00::1",        // ULA
64
+	}
65
+	for _, addr := range forbidden {
66
+		ip := net.ParseIP(addr)
67
+		if ip == nil {
68
+			t.Fatalf("bad test fixture: %q", addr)
69
+		}
70
+		if !isForbiddenIP(ip) {
71
+			t.Errorf("isForbiddenIP(%q) = false; want true", addr)
72
+		}
73
+	}
74
+	allowed := []string{
75
+		"1.1.1.1", "8.8.8.8", "203.0.113.5", "198.51.100.7",
76
+		"2001:4860:4860::8888",
77
+	}
78
+	for _, addr := range allowed {
79
+		ip := net.ParseIP(addr)
80
+		if ip == nil {
81
+			t.Fatalf("bad test fixture: %q", addr)
82
+		}
83
+		if isForbiddenIP(ip) {
84
+			t.Errorf("isForbiddenIP(%q) = true; want false", addr)
85
+		}
86
+	}
87
+}
88
+
89
+func TestIsSSRF(t *testing.T) {
90
+	c := DefaultSSRFConfig()
91
+	err := c.Validate("file:///etc/passwd")
92
+	if !IsSSRF(err) {
93
+		t.Fatalf("IsSSRF(%v) = false; want true", err)
94
+	}
95
+	if IsSSRF(nil) {
96
+		t.Fatalf("IsSSRF(nil) = true; want false")
97
+	}
98
+}