Go · 2071 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package webhook
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "testing"
10 )
11
12 // TestSSRFRejectsLoopbackAndPrivateAtCreateTime pins SR2 H3:
13 // Create/Update must reject loopback / RFC1918 / disallowed-port URLs
14 // synchronously instead of letting them persist and only fail on
15 // every delivery attempt.
16 //
17 // Production code calls cfg.ValidateWithResolve(ctx, url) — that's
18 // what Create() and Update() use. The plain Validate() is the cheap
19 // syntactic gate; ValidateWithResolve adds the IP-block-list check
20 // that catches loopback + RFC1918 hosts (literal and resolved).
21 func TestSSRFRejectsLoopbackAndPrivateAtCreateTime(t *testing.T) {
22 t.Parallel()
23
24 cfg := DefaultSSRFConfig()
25 rejected := []string{
26 "http://127.0.0.1/hook",
27 "http://127.0.0.1:8080/hook",
28 "http://[::1]/hook",
29 "http://192.168.1.1/hook",
30 "http://10.0.0.1/hook",
31 "http://172.16.0.1/hook",
32 // Disallowed port (only 80/443/8080/8443 pass by default).
33 "http://example.com:9090/hook",
34 }
35 ctx := context.Background()
36 for _, u := range rejected {
37 t.Run(u, func(t *testing.T) {
38 t.Parallel()
39 err := cfg.ValidateWithResolve(ctx, u)
40 if err == nil {
41 t.Fatalf("SSRF.ValidateWithResolve(%q) = nil; expected an error", u)
42 }
43 })
44 }
45
46 // Sanity: a public-looking URL on a default port should pass.
47 if err := cfg.ValidateWithResolve(ctx, "https://example.com/hook"); err != nil {
48 t.Fatalf("SSRF.ValidateWithResolve(public) = %v; expected nil", err)
49 }
50 }
51
52 // TestCreateUpdateWrapsSSRFErrInBadURL pins the error contract used
53 // by Create/Update: an SSRF rejection wraps ErrBadURL so callers can
54 // errors.Is(err, ErrBadURL) for form-shaped feedback. Production code
55 // uses fmt.Errorf("%w: %v", ErrBadURL, err) — this test pins that the
56 // wrap shape unwraps correctly.
57 func TestCreateUpdateWrapsSSRFErrInBadURL(t *testing.T) {
58 t.Parallel()
59
60 inner := errors.New("ssrf: loopback")
61 wrapped := fmt.Errorf("%w: %v", ErrBadURL, inner)
62
63 if !errors.Is(wrapped, ErrBadURL) {
64 t.Fatalf("wrapped error is not ErrBadURL: %v", wrapped)
65 }
66 }
67