| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | // Package ssrf is the canonical SSRF-defense layer (S35). The S33 |
| 4 | // webhook deliverer was the first caller; future outbound-fetch |
| 5 | // paths (avatar mirroring, OG-image scraping, OAuth metadata |
| 6 | // fetches, etc.) plug into the same Config + HTTPClient. |
| 7 | // |
| 8 | // The pattern, copied from the webhook package and codified here: |
| 9 | // |
| 10 | // 1. Resolve the hostname. |
| 11 | // 2. Reject the request if ANY resolved IP falls in a private, |
| 12 | // loopback, link-local, CGNAT, multicast, or ULA range. |
| 13 | // 3. Dial the resolved IP directly so a re-resolve at connect |
| 14 | // time can't rebind to a private address (DNS rebinding). |
| 15 | // 4. Reject schemes other than http/https and ports outside the |
| 16 | // operator's allow-list (default 80/443/8080/8443). |
| 17 | // |
| 18 | // Operators with a self-hosted CI behind a private IP set |
| 19 | // AllowedHosts (exact-match, case-insensitive) which implicitly |
| 20 | // allows the host to bypass the IP block-list. Don't confuse with |
| 21 | // AllowPrivateNetworks (global bypass — for testing only). |
| 22 | package ssrf |
| 23 | |
| 24 | import ( |
| 25 | "context" |
| 26 | "errors" |
| 27 | "fmt" |
| 28 | "net" |
| 29 | "net/http" |
| 30 | "net/url" |
| 31 | "strconv" |
| 32 | "time" |
| 33 | ) |
| 34 | |
// Error is the rejection value produced by every check in this package.
// URL holds whatever was being checked (a full URL for Validate /
// ValidateWithResolve, a host:port dial address at connect time) and
// Reason a short human-readable explanation. Callers can recover it with
// errors.As (or the Is helper) and surface the reason in admin logs / UI.
type Error struct {
	URL    string
	Reason string
}

// Error renders the rejection as "ssrf: <reason>: <url>".
func (e *Error) Error() string {
	msg := fmt.Sprintf("ssrf: %s: %s", e.Reason, e.URL)
	return msg
}

// Is reports whether err is an *Error or wraps one anywhere in its
// Unwrap chain. A nil err reports false.
func Is(err error) bool {
	target := new(*Error)
	return errors.As(err, target)
}
| 49 | |
// Config tunes the defense. Defaults are safe for a single-tenant
// public deployment; self-hosters extend AllowedHosts / AllowedPorts.
//
// HTTPClient, Validate and ValidateWithResolve all call applyDefaults
// first. Zero-valued AllowedSchemes, AllowedPorts, DialTimeout and
// RequestTimeout are therefore replaced by the Default() values, so the
// zero Config is usable.
type Config struct {
	// AllowedSchemes lists the URL schemes Validate accepts. The check
	// is an exact, case-sensitive comparison against the parsed scheme;
	// net/url lowercases the scheme, so entries should be lowercase.
	// Empty means {"http", "https"}.
	AllowedSchemes []string
	// AllowedPorts lists destination ports accepted both by Validate
	// and again at dial time. Empty means {80, 443, 8080, 8443}.
	AllowedPorts []int
	// AllowPrivateNetworks disables the forbidden-IP check globally,
	// in ValidateWithResolve and at dial time. Intended for testing only.
	AllowPrivateNetworks bool
	// AllowedHosts lists hostnames exempt from the forbidden-IP check
	// (exact match, ASCII case-insensitive). Scheme and port rules still
	// apply, and the host is still resolved at dial time.
	AllowedHosts []string
	// DialTimeout is the net.Dialer timeout used for outbound dials.
	// Zero means 10s.
	DialTimeout time.Duration
	// RequestTimeout is used both as the http.Client timeout and as the
	// transport's ResponseHeaderTimeout. Zero means 30s.
	RequestTimeout time.Duration
	// Resolver performs the DNS lookups; nil means net.DefaultResolver.
	Resolver *net.Resolver
}
| 61 | |
| 62 | // Default returns the production defaults. |
| 63 | func Default() Config { |
| 64 | return Config{ |
| 65 | AllowedSchemes: []string{"http", "https"}, |
| 66 | AllowedPorts: []int{80, 443, 8080, 8443}, |
| 67 | DialTimeout: 10 * time.Second, |
| 68 | RequestTimeout: 30 * time.Second, |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | // HTTPClient returns a redirect-following-disabled client whose |
| 73 | // dialer enforces the SSRF rules. 3xx is treated as success-like |
| 74 | // (no-follow) since following is itself an SSRF amplifier. |
| 75 | func (c Config) HTTPClient() *http.Client { |
| 76 | cfg := c.applyDefaults() |
| 77 | tr := &http.Transport{ |
| 78 | DialContext: cfg.dialContext, |
| 79 | ResponseHeaderTimeout: cfg.RequestTimeout, |
| 80 | ForceAttemptHTTP2: false, |
| 81 | DisableKeepAlives: true, |
| 82 | } |
| 83 | return &http.Client{ |
| 84 | Transport: tr, |
| 85 | Timeout: cfg.RequestTimeout, |
| 86 | CheckRedirect: func(*http.Request, []*http.Request) error { |
| 87 | return http.ErrUseLastResponse |
| 88 | }, |
| 89 | } |
| 90 | } |
| 91 | |
| 92 | // ValidateWithResolve runs Validate plus a DNS resolve so callers can |
| 93 | // reject loopback/private/CGNAT/multicast hosts at create-time, not |
| 94 | // only at delivery-time. dialContext still re-resolves on each dial |
| 95 | // (DNS rebinding defense) — this is the cheap-but-thorough gate |
| 96 | // admin forms call so the persisted hook can't sit broken-on-arrival |
| 97 | // (SR2 H3). It's NOT a substitute for the dial-time check. |
| 98 | // |
| 99 | // A nil Resolver uses net.DefaultResolver. AllowedHosts (exact match |
| 100 | // case-insensitive) bypasses the IP block-list as in dialContext. |
| 101 | func (c Config) ValidateWithResolve(ctx context.Context, rawURL string) error { |
| 102 | if err := c.Validate(rawURL); err != nil { |
| 103 | return err |
| 104 | } |
| 105 | cfg := c.applyDefaults() |
| 106 | u, err := url.Parse(rawURL) |
| 107 | if err != nil { |
| 108 | return &Error{URL: rawURL, Reason: "malformed URL"} |
| 109 | } |
| 110 | host := u.Hostname() |
| 111 | if stringSetContainsFold(cfg.AllowedHosts, host) { |
| 112 | return nil |
| 113 | } |
| 114 | if cfg.AllowPrivateNetworks { |
| 115 | return nil |
| 116 | } |
| 117 | // IP literal — check directly without DNS. |
| 118 | if ip := net.ParseIP(host); ip != nil { |
| 119 | if IsForbiddenIP(ip) { |
| 120 | return &Error{URL: rawURL, Reason: "host resolves to forbidden IP " + ip.String()} |
| 121 | } |
| 122 | return nil |
| 123 | } |
| 124 | // Hostname — resolve and check every result. |
| 125 | resolver := cfg.Resolver |
| 126 | if resolver == nil { |
| 127 | resolver = net.DefaultResolver |
| 128 | } |
| 129 | ips, err := resolver.LookupIPAddr(ctx, host) |
| 130 | if err != nil { |
| 131 | return &Error{URL: rawURL, Reason: "DNS resolve: " + err.Error()} |
| 132 | } |
| 133 | if len(ips) == 0 { |
| 134 | return &Error{URL: rawURL, Reason: "no IPs resolved"} |
| 135 | } |
| 136 | for _, ipa := range ips { |
| 137 | if IsForbiddenIP(ipa.IP) { |
| 138 | return &Error{URL: rawURL, Reason: "host resolves to forbidden IP " + ipa.IP.String()} |
| 139 | } |
| 140 | } |
| 141 | return nil |
| 142 | } |
| 143 | |
| 144 | // Validate runs the syntactic gate (scheme, port, host shape) without |
| 145 | // resolving DNS. dialContext re-runs the IP-level checks at connect |
| 146 | // time so a passing Validate doesn't imply the URL is safe to dial — |
| 147 | // it's the early-rejection cheap gate. For full create-time rejection |
| 148 | // (loopback, private IPs, etc.) use ValidateWithResolve. |
| 149 | func (c Config) Validate(rawURL string) error { |
| 150 | cfg := c.applyDefaults() |
| 151 | u, err := url.Parse(rawURL) |
| 152 | if err != nil { |
| 153 | return &Error{URL: rawURL, Reason: "malformed URL"} |
| 154 | } |
| 155 | if !stringSetContains(cfg.AllowedSchemes, u.Scheme) { |
| 156 | return &Error{URL: rawURL, Reason: "scheme " + u.Scheme + " not allowed"} |
| 157 | } |
| 158 | host := u.Hostname() |
| 159 | if host == "" { |
| 160 | return &Error{URL: rawURL, Reason: "missing host"} |
| 161 | } |
| 162 | port := u.Port() |
| 163 | if port == "" { |
| 164 | switch u.Scheme { |
| 165 | case "http": |
| 166 | port = "80" |
| 167 | case "https": |
| 168 | port = "443" |
| 169 | } |
| 170 | } |
| 171 | pn, perr := strconv.Atoi(port) |
| 172 | if perr != nil || pn <= 0 || pn > 65535 { |
| 173 | return &Error{URL: rawURL, Reason: "invalid port"} |
| 174 | } |
| 175 | if !intSetContains(cfg.AllowedPorts, pn) { |
| 176 | return &Error{URL: rawURL, Reason: "port " + port + " not in allow-list"} |
| 177 | } |
| 178 | return nil |
| 179 | } |
| 180 | |
| 181 | func (c Config) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { |
| 182 | host, port, err := net.SplitHostPort(addr) |
| 183 | if err != nil { |
| 184 | return nil, &Error{URL: addr, Reason: "split host:port: " + err.Error()} |
| 185 | } |
| 186 | pn, _ := strconv.Atoi(port) |
| 187 | if !intSetContains(c.AllowedPorts, pn) { |
| 188 | return nil, &Error{URL: addr, Reason: "port " + port + " not in allow-list"} |
| 189 | } |
| 190 | hostAllowed := stringSetContainsFold(c.AllowedHosts, host) |
| 191 | resolver := c.Resolver |
| 192 | if resolver == nil { |
| 193 | resolver = net.DefaultResolver |
| 194 | } |
| 195 | ips, err := resolver.LookupIPAddr(ctx, host) |
| 196 | if err != nil { |
| 197 | return nil, &Error{URL: addr, Reason: "DNS resolve: " + err.Error()} |
| 198 | } |
| 199 | if len(ips) == 0 { |
| 200 | return nil, &Error{URL: addr, Reason: "no IPs resolved"} |
| 201 | } |
| 202 | for _, ipa := range ips { |
| 203 | if !hostAllowed && !c.AllowPrivateNetworks && IsForbiddenIP(ipa.IP) { |
| 204 | return nil, &Error{URL: addr, Reason: "resolved to forbidden IP " + ipa.IP.String()} |
| 205 | } |
| 206 | } |
| 207 | dialer := &net.Dialer{Timeout: c.DialTimeout} |
| 208 | dialAddr := net.JoinHostPort(ips[0].IP.String(), port) |
| 209 | return dialer.DialContext(ctx, network, dialAddr) |
| 210 | } |
| 211 | |
| 212 | func (c Config) applyDefaults() Config { |
| 213 | def := Default() |
| 214 | if len(c.AllowedSchemes) == 0 { |
| 215 | c.AllowedSchemes = def.AllowedSchemes |
| 216 | } |
| 217 | if len(c.AllowedPorts) == 0 { |
| 218 | c.AllowedPorts = def.AllowedPorts |
| 219 | } |
| 220 | if c.DialTimeout == 0 { |
| 221 | c.DialTimeout = def.DialTimeout |
| 222 | } |
| 223 | if c.RequestTimeout == 0 { |
| 224 | c.RequestTimeout = def.RequestTimeout |
| 225 | } |
| 226 | return c |
| 227 | } |
| 228 | |
// IsForbiddenIP reports whether ip belongs to a range that must never
// be dialed. It fails closed: a nil or malformed IP (neither 4 nor 16
// bytes long) is forbidden.
//
// Forbidden:
//   - unspecified, loopback, link-local unicast and all multicast
//     addresses (via the net.IP predicates);
//   - IPv4 0.0.0.0/8, 10.0.0.0/8, 100.64.0.0/10 (CGNAT), 169.254.0.0/16,
//     172.16.0.0/12, 192.168.0.0/16 and 240.0.0.0/4 (reserved, including
//     the 255.255.255.255 broadcast);
//   - IPv6 fc00::/7 (ULA) and fec0::/10 (deprecated site-local);
//   - IPv6 addresses that carry an embedded IPv4 address, when that
//     embedded address is itself forbidden. Covered forms are
//     IPv4-mapped (::ffff:0:0/96), IPv4-compatible (::/96), NAT64
//     (64:ff9b::/96) and 6to4 (2002::/16). A translator or relay may
//     deliver such packets to the embedded IPv4 host.
//
// Exposed so callers running their own validation (e.g. a config-time
// URL probe) can reuse the check.
func IsForbiddenIP(ip net.IP) bool {
	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
		return true // nil or malformed: fail closed
	}
	if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() ||
		ip.IsMulticast() {
		return true
	}

	// IPv4 and IPv4-mapped IPv6.
	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0: // 0.0.0.0/8 "this network"
			return true
		case ip4[0] == 10: // 10.0.0.0/8
			return true
		case ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127: // 100.64.0.0/10 CGNAT
			return true
		case ip4[0] == 169 && ip4[1] == 254: // 169.254.0.0/16 link-local
			return true
		case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: // 172.16.0.0/12
			return true
		case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16
			return true
		case ip4[0] >= 240: // 240.0.0.0/4 reserved + broadcast
			return true
		}
		return false
	}

	// From here ip is a 16-byte IPv6 address that is not IPv4-mapped.
	allZero := func(b []byte) bool {
		for _, x := range b {
			if x != 0 {
				return false
			}
		}
		return true
	}
	embedded := func(b []byte) bool {
		return IsForbiddenIP(net.IPv4(b[0], b[1], b[2], b[3]))
	}
	switch {
	case ip[0]&0xfe == 0xfc: // fc00::/7 unique local
		return true
	case ip[0] == 0xfe && ip[1]&0xc0 == 0xc0: // fec0::/10 site-local
		return true
	case allZero(ip[:12]): // ::/96 IPv4-compatible (:: and ::1 handled above)
		return embedded(ip[12:16])
	case ip[0] == 0x00 && ip[1] == 0x64 && ip[2] == 0xff && ip[3] == 0x9b && allZero(ip[4:12]): // 64:ff9b::/96 NAT64
		return embedded(ip[12:16])
	case ip[0] == 0x20 && ip[1] == 0x02: // 2002::/16 6to4
		return embedded(ip[2:6])
	}
	return false
}
| 265 | |
// stringSetContains reports whether v occurs in set using exact,
// case-sensitive string equality. A nil or empty set contains nothing.
func stringSetContains(set []string, v string) bool {
	found := false
	for i := 0; i < len(set) && !found; i++ {
		found = set[i] == v
	}
	return found
}
| 274 | |
// stringSetContainsFold reports whether any element of set equals v
// under ASCII-only case folding (see equalFold). A nil or empty set
// contains nothing.
func stringSetContainsFold(set []string, v string) bool {
	for i := range set {
		if equalFold(set[i], v) {
			return true
		}
	}
	return false
}

// equalFold compares a and b byte by byte, treating ASCII 'A'-'Z' as
// equal to 'a'-'z'. Every other byte must match exactly.
//
// Unlike strings.EqualFold it does no Unicode case folding. For example,
// U+212A KELVIN SIGN is not equal to "k" here, so host matching stays
// strictly ASCII.
func equalFold(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	lower := func(ch byte) byte {
		if ch >= 'A' && ch <= 'Z' {
			return ch + ('a' - 'A')
		}
		return ch
	}
	for i := range []byte(a) {
		if lower(a[i]) != lower(b[i]) {
			return false
		}
	}
	return true
}
| 302 | |
// intSetContains reports whether v occurs in set. A nil or empty set
// contains nothing.
func intSetContains(set []int, v int) bool {
	for i := 0; i < len(set); i++ {
		if set[i] != v {
			continue
		}
		return true
	}
	return false
}
| 311 |