| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package handlers |
| 4 | |
| 5 | import ( |
| 6 | "net/http" |
| 7 | "net/http/httptest" |
| 8 | "testing" |
| 9 | ) |
| 10 | |
| 11 | func TestPublicBaseURLRejectsUnsafeRequestHostFallback(t *testing.T) { |
| 12 | t.Parallel() |
| 13 | |
| 14 | req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) |
| 15 | req.Host = "example.com\r\nSitemap: https://evil.test/sitemap.xml" |
| 16 | if got := publicBaseURL("", req); got != "" { |
| 17 | t.Fatalf("publicBaseURL accepted unsafe host = %q", got) |
| 18 | } |
| 19 | } |
| 20 | |
| 21 | func TestPublicBaseURLPrefersConfiguredBase(t *testing.T) { |
| 22 | t.Parallel() |
| 23 | |
| 24 | req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) |
| 25 | req.Host = "untrusted.example" |
| 26 | if got, want := publicBaseURL("https://shithub.sh/", req), "https://shithub.sh"; got != want { |
| 27 | t.Fatalf("publicBaseURL = %q, want %q", got, want) |
| 28 | } |
| 29 | } |
| 30 |