Go · 8701 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package ssrf is the canonical SSRF-defense layer (S35). The S33
4 // webhook deliverer was the first caller; future outbound-fetch
5 // paths (avatar mirroring, OG-image scraping, OAuth metadata
6 // fetches, etc.) plug into the same Config + HTTPClient.
7 //
8 // The pattern, copied from the webhook package and codified here:
9 //
10 // 1. Resolve the hostname.
11 // 2. Reject the request if ANY resolved IP falls in a private,
12 // loopback, link-local, CGNAT, multicast, or ULA range.
13 // 3. Dial the resolved IP directly so a re-resolve at connect
14 // time can't rebind to a private address (DNS rebinding).
15 // 4. Reject schemes other than http/https and ports outside the
16 // operator's allow-list (default 80/443/8080/8443).
17 //
18 // Operators whose self-hosted CI sits behind a private IP can list
19 // that host in AllowedHosts (exact match, ASCII case-insensitive);
20 // a listed host bypasses the IP block-list. Don't confuse this with
21 // AllowPrivateNetworks, which bypasses it globally (testing only).
22 package ssrf
23
24 import (
25 "context"
26 "errors"
27 "fmt"
28 "net"
29 "net/http"
30 "net/url"
31 "strconv"
32 "time"
33 )
34
35 // Error wraps a rejection so callers can `errors.As` to the typed
36 // shape and surface the specific reason in admin logs / UI.
37 type Error struct {
38 URL string
39 Reason string
40 }
41
42 func (e *Error) Error() string { return fmt.Sprintf("ssrf: %s: %s", e.Reason, e.URL) }
43
44 // Is reports whether err is or wraps an *Error.
45 func Is(err error) bool {
46 var s *Error
47 return errors.As(err, &s)
48 }
49
// Config tunes the defense. Defaults are safe for a single-tenant public
// deployment; self-hosters extend AllowedHosts / AllowedPorts.
//
// HTTPClient, Validate and ValidateWithResolve fill empty/zero
// AllowedSchemes, AllowedPorts, DialTimeout and RequestTimeout from
// Default(), so a zero Config behaves like Default().
type Config struct {
	// AllowedSchemes lists the URL schemes Validate accepts (exact match
	// against the parsed scheme). Empty means {"http", "https"}.
	AllowedSchemes []string

	// AllowedPorts lists the destination ports accepted by Validate and
	// re-checked at dial time. Empty means {80, 443, 8080, 8443}.
	AllowedPorts []int

	// AllowPrivateNetworks disables the IP block-list for every host.
	// Intended for testing only.
	AllowPrivateNetworks bool

	// AllowedHosts are hostnames (ASCII case-insensitive exact match on
	// the URL host) that bypass the IP block-list.
	AllowedHosts []string

	// DialTimeout bounds each TCP dial (zero means 10s). RequestTimeout
	// is used as both the Transport's ResponseHeaderTimeout and the
	// client's overall Timeout (zero means 30s).
	DialTimeout, RequestTimeout time.Duration

	// Resolver performs DNS lookups; nil means net.DefaultResolver.
	Resolver *net.Resolver
}
61
62 // Default returns the production defaults.
63 func Default() Config {
64 return Config{
65 AllowedSchemes: []string{"http", "https"},
66 AllowedPorts: []int{80, 443, 8080, 8443},
67 DialTimeout: 10 * time.Second,
68 RequestTimeout: 30 * time.Second,
69 }
70 }
71
72 // HTTPClient returns a redirect-following-disabled client whose
73 // dialer enforces the SSRF rules. 3xx is treated as success-like
74 // (no-follow) since following is itself an SSRF amplifier.
75 func (c Config) HTTPClient() *http.Client {
76 cfg := c.applyDefaults()
77 tr := &http.Transport{
78 DialContext: cfg.dialContext,
79 ResponseHeaderTimeout: cfg.RequestTimeout,
80 ForceAttemptHTTP2: false,
81 DisableKeepAlives: true,
82 }
83 return &http.Client{
84 Transport: tr,
85 Timeout: cfg.RequestTimeout,
86 CheckRedirect: func(*http.Request, []*http.Request) error {
87 return http.ErrUseLastResponse
88 },
89 }
90 }
91
92 // ValidateWithResolve runs Validate plus a DNS resolve so callers can
93 // reject loopback/private/CGNAT/multicast hosts at create-time, not
94 // only at delivery-time. dialContext still re-resolves on each dial
95 // (DNS rebinding defense) — this is the cheap-but-thorough gate
96 // admin forms call so the persisted hook can't sit broken-on-arrival
97 // (SR2 H3). It's NOT a substitute for the dial-time check.
98 //
99 // A nil Resolver uses net.DefaultResolver. AllowedHosts (exact match
100 // case-insensitive) bypasses the IP block-list as in dialContext.
101 func (c Config) ValidateWithResolve(ctx context.Context, rawURL string) error {
102 if err := c.Validate(rawURL); err != nil {
103 return err
104 }
105 cfg := c.applyDefaults()
106 u, err := url.Parse(rawURL)
107 if err != nil {
108 return &Error{URL: rawURL, Reason: "malformed URL"}
109 }
110 host := u.Hostname()
111 if stringSetContainsFold(cfg.AllowedHosts, host) {
112 return nil
113 }
114 if cfg.AllowPrivateNetworks {
115 return nil
116 }
117 // IP literal — check directly without DNS.
118 if ip := net.ParseIP(host); ip != nil {
119 if IsForbiddenIP(ip) {
120 return &Error{URL: rawURL, Reason: "host resolves to forbidden IP " + ip.String()}
121 }
122 return nil
123 }
124 // Hostname — resolve and check every result.
125 resolver := cfg.Resolver
126 if resolver == nil {
127 resolver = net.DefaultResolver
128 }
129 ips, err := resolver.LookupIPAddr(ctx, host)
130 if err != nil {
131 return &Error{URL: rawURL, Reason: "DNS resolve: " + err.Error()}
132 }
133 if len(ips) == 0 {
134 return &Error{URL: rawURL, Reason: "no IPs resolved"}
135 }
136 for _, ipa := range ips {
137 if IsForbiddenIP(ipa.IP) {
138 return &Error{URL: rawURL, Reason: "host resolves to forbidden IP " + ipa.IP.String()}
139 }
140 }
141 return nil
142 }
143
144 // Validate runs the syntactic gate (scheme, port, host shape) without
145 // resolving DNS. dialContext re-runs the IP-level checks at connect
146 // time so a passing Validate doesn't imply the URL is safe to dial —
147 // it's the early-rejection cheap gate. For full create-time rejection
148 // (loopback, private IPs, etc.) use ValidateWithResolve.
149 func (c Config) Validate(rawURL string) error {
150 cfg := c.applyDefaults()
151 u, err := url.Parse(rawURL)
152 if err != nil {
153 return &Error{URL: rawURL, Reason: "malformed URL"}
154 }
155 if !stringSetContains(cfg.AllowedSchemes, u.Scheme) {
156 return &Error{URL: rawURL, Reason: "scheme " + u.Scheme + " not allowed"}
157 }
158 host := u.Hostname()
159 if host == "" {
160 return &Error{URL: rawURL, Reason: "missing host"}
161 }
162 port := u.Port()
163 if port == "" {
164 switch u.Scheme {
165 case "http":
166 port = "80"
167 case "https":
168 port = "443"
169 }
170 }
171 pn, perr := strconv.Atoi(port)
172 if perr != nil || pn <= 0 || pn > 65535 {
173 return &Error{URL: rawURL, Reason: "invalid port"}
174 }
175 if !intSetContains(cfg.AllowedPorts, pn) {
176 return &Error{URL: rawURL, Reason: "port " + port + " not in allow-list"}
177 }
178 return nil
179 }
180
181 func (c Config) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
182 host, port, err := net.SplitHostPort(addr)
183 if err != nil {
184 return nil, &Error{URL: addr, Reason: "split host:port: " + err.Error()}
185 }
186 pn, _ := strconv.Atoi(port)
187 if !intSetContains(c.AllowedPorts, pn) {
188 return nil, &Error{URL: addr, Reason: "port " + port + " not in allow-list"}
189 }
190 hostAllowed := stringSetContainsFold(c.AllowedHosts, host)
191 resolver := c.Resolver
192 if resolver == nil {
193 resolver = net.DefaultResolver
194 }
195 ips, err := resolver.LookupIPAddr(ctx, host)
196 if err != nil {
197 return nil, &Error{URL: addr, Reason: "DNS resolve: " + err.Error()}
198 }
199 if len(ips) == 0 {
200 return nil, &Error{URL: addr, Reason: "no IPs resolved"}
201 }
202 for _, ipa := range ips {
203 if !hostAllowed && !c.AllowPrivateNetworks && IsForbiddenIP(ipa.IP) {
204 return nil, &Error{URL: addr, Reason: "resolved to forbidden IP " + ipa.IP.String()}
205 }
206 }
207 dialer := &net.Dialer{Timeout: c.DialTimeout}
208 dialAddr := net.JoinHostPort(ips[0].IP.String(), port)
209 return dialer.DialContext(ctx, network, dialAddr)
210 }
211
212 func (c Config) applyDefaults() Config {
213 def := Default()
214 if len(c.AllowedSchemes) == 0 {
215 c.AllowedSchemes = def.AllowedSchemes
216 }
217 if len(c.AllowedPorts) == 0 {
218 c.AllowedPorts = def.AllowedPorts
219 }
220 if c.DialTimeout == 0 {
221 c.DialTimeout = def.DialTimeout
222 }
223 if c.RequestTimeout == 0 {
224 c.RequestTimeout = def.RequestTimeout
225 }
226 return c
227 }
228
// IsForbiddenIP reports whether ip belongs to a range that must never be
// dialed. A nil or malformed-length IP is treated as forbidden (fail
// closed). The blocked ranges are:
//
//   - unspecified, loopback, link-local unicast/multicast,
//     interface-local multicast and all other multicast addresses;
//   - IPv4 private 10/8, 172.16/12 and 192.168/16;
//   - IPv4 CGNAT 100.64/10, link-local 169.254/16 and "this network" 0/8;
//   - IPv4 reserved 240.0.0.0/4, which includes 255.255.255.255;
//   - IPv6 unique-local fc00::/7 and deprecated site-local fec0::/10;
//   - IPv6 addresses that embed an IPv4 address, when that IPv4 address is
//     itself forbidden: NAT64 64:ff9b::/96, 6to4 2002::/16 and
//     IPv4-compatible ::/96. Without this, an AAAA record such as
//     64:ff9b::7f00:1 could reach 127.0.0.1 through a NAT64 gateway.
//     IPv4-mapped ::ffff:0:0/96 addresses are unwrapped by To4 and checked
//     as plain IPv4.
//
// Exposed so callers running their own validation (e.g. a config-time URL
// probe) can reuse the check.
func IsForbiddenIP(ip net.IP) bool {
	if ip == nil {
		return true
	}
	if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() ||
		ip.IsMulticast() {
		return true
	}
	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 10:
			return true
		case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31:
			return true
		case ip4[0] == 192 && ip4[1] == 168:
			return true
		case ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127:
			return true
		case ip4[0] == 169 && ip4[1] == 254:
			return true
		case ip4[0] == 0:
			return true
		case ip4[0] >= 240: // 240.0.0.0/4 reserved, incl. limited broadcast
			return true
		}
		return false
	}
	if len(ip) != net.IPv6len {
		// Neither a valid IPv4 nor IPv6 address: fail closed like nil.
		return true
	}
	if ip[0]&0xfe == 0xfc { // fc00::/7 unique-local
		return true
	}
	if ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // fec0::/10 site-local (deprecated)
		return true
	}
	if v4 := embeddedIPv4(ip); v4 != nil {
		// v4 is IPv4-mapped, so this recursion ends in the IPv4 branch.
		return IsForbiddenIP(v4)
	}
	return false
}

// embeddedIPv4 returns the IPv4 address carried inside a 16-byte IPv6
// address that uses a well-known embedding: 6to4 (2002::/16, bytes 2-5),
// NAT64 (64:ff9b::/96, bytes 12-15) or IPv4-compatible (::/96, bytes
// 12-15). For any other address it returns nil. The caller must ensure
// len(ip) == net.IPv6len.
func embeddedIPv4(ip net.IP) net.IP {
	zeroRange := func(lo, hi int) bool {
		for _, b := range ip[lo:hi] {
			if b != 0 {
				return false
			}
		}
		return true
	}
	switch {
	case ip[0] == 0x20 && ip[1] == 0x02:
		return net.IPv4(ip[2], ip[3], ip[4], ip[5])
	case ip[0] == 0x00 && ip[1] == 0x64 && ip[2] == 0xff && ip[3] == 0x9b && zeroRange(4, 12):
		return net.IPv4(ip[12], ip[13], ip[14], ip[15])
	case zeroRange(0, 12):
		return net.IPv4(ip[12], ip[13], ip[14], ip[15])
	}
	return nil
}
265
// stringSetContains reports whether v is an element of set, using exact
// (case-sensitive) comparison. A nil or empty set contains nothing.
func stringSetContains(set []string, v string) bool {
	idx := 0
	for idx < len(set) && set[idx] != v {
		idx++
	}
	return idx < len(set)
}
274
275 func stringSetContainsFold(set []string, v string) bool {
276 for _, s := range set {
277 if equalFold(s, v) {
278 return true
279 }
280 }
281 return false
282 }
283
// equalFold reports whether a and b are equal under ASCII-only case
// folding. Only bytes 'A'..'Z' are lowered before comparing, and the two
// strings must have the same byte length. Unlike strings.EqualFold, no
// Unicode folding is done, so non-ASCII bytes must match exactly.
func equalFold(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	toLower := func(ch byte) byte {
		if ch < 'A' || ch > 'Z' {
			return ch
		}
		// For 'A'..'Z', setting bit 5 maps to 'a'..'z'.
		return ch | 0x20
	}
	for i := 0; i < len(a); i++ {
		if toLower(a[i]) != toLower(b[i]) {
			return false
		}
	}
	return true
}
302
// intSetContains reports whether v is an element of set. A nil or empty
// set contains nothing.
func intSetContains(set []int, v int) bool {
	for i := range set {
		if set[i] != v {
			continue
		}
		return true
	}
	return false
}
311