| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package session |
| 4 | |
| 5 | import ( |
| 6 | "net/http" |
| 7 | "net/http/httptest" |
| 8 | "strings" |
| 9 | "testing" |
| 10 | "time" |
| 11 | ) |
| 12 | |
| 13 | func newStore(t *testing.T) *CookieStore { |
| 14 | t.Helper() |
| 15 | key, err := GenerateKey() |
| 16 | if err != nil { |
| 17 | t.Fatalf("GenerateKey: %v", err) |
| 18 | } |
| 19 | store, err := NewCookieStore(CookieStoreConfig{Key: key}) |
| 20 | if err != nil { |
| 21 | t.Fatalf("NewCookieStore: %v", err) |
| 22 | } |
| 23 | return store |
| 24 | } |
| 25 | |
| 26 | func TestCookieStore_SaveLoadRoundTrip(t *testing.T) { |
| 27 | t.Parallel() |
| 28 | store := newStore(t) |
| 29 | |
| 30 | rec := httptest.NewRecorder() |
| 31 | saved := &Session{UserID: 42, Theme: "dark"} |
| 32 | saved.AddFlash("welcome back") |
| 33 | if err := store.Save(rec, httptest.NewRequest(http.MethodGet, "/", nil), saved); err != nil { |
| 34 | t.Fatalf("Save: %v", err) |
| 35 | } |
| 36 | cookies := rec.Result().Cookies() |
| 37 | if len(cookies) != 1 || cookies[0].Name != CookieName { |
| 38 | t.Fatalf("expected one %s cookie, got %v", CookieName, cookies) |
| 39 | } |
| 40 | |
| 41 | // Round-trip via a fresh request carrying the cookie. |
| 42 | req := httptest.NewRequest(http.MethodGet, "/", nil) |
| 43 | req.AddCookie(cookies[0]) |
| 44 | loaded, err := store.Load(req) |
| 45 | if err != nil { |
| 46 | t.Fatalf("Load: %v", err) |
| 47 | } |
| 48 | if loaded.UserID != 42 { |
| 49 | t.Errorf("UserID: got %d, want 42", loaded.UserID) |
| 50 | } |
| 51 | if loaded.Theme != "dark" { |
| 52 | t.Errorf("Theme: got %q, want dark", loaded.Theme) |
| 53 | } |
| 54 | flashes := loaded.PopFlashes() |
| 55 | if len(flashes) != 1 || flashes[0] != "welcome back" { |
| 56 | t.Errorf("Flashes: got %v", flashes) |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | func TestCookieStore_TamperedCookieYieldsEmptySession(t *testing.T) { |
| 61 | t.Parallel() |
| 62 | store := newStore(t) |
| 63 | req := httptest.NewRequest(http.MethodGet, "/", nil) |
| 64 | //nolint:gosec // G124: test fixture intentionally constructs a malformed cookie to verify Load tolerates it. |
| 65 | req.AddCookie(&http.Cookie{Name: CookieName, Value: "this-is-not-a-valid-aead-payload"}) |
| 66 | loaded, err := store.Load(req) |
| 67 | if err != nil { |
| 68 | t.Fatalf("Load: %v", err) |
| 69 | } |
| 70 | if !loaded.IsAnonymous() || loaded.UserID != 0 { |
| 71 | t.Errorf("expected anonymous session, got %+v", loaded) |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | func TestCookieStore_ExpiredSessionYieldsEmpty(t *testing.T) { |
| 76 | t.Parallel() |
| 77 | store := newStore(t) |
| 78 | store.maxAge = 1 * time.Second |
| 79 | store.clock = func() time.Time { return time.Unix(1_000_000, 0) } |
| 80 | |
| 81 | rec := httptest.NewRecorder() |
| 82 | if err := store.Save(rec, httptest.NewRequest(http.MethodGet, "/", nil), &Session{UserID: 7}); err != nil { |
| 83 | t.Fatalf("Save: %v", err) |
| 84 | } |
| 85 | cookie := rec.Result().Cookies()[0] |
| 86 | |
| 87 | // Advance the clock past expiry. |
| 88 | store.clock = func() time.Time { return time.Unix(1_000_000+10, 0) } |
| 89 | req := httptest.NewRequest(http.MethodGet, "/", nil) |
| 90 | req.AddCookie(cookie) |
| 91 | loaded, _ := store.Load(req) |
| 92 | if !loaded.IsAnonymous() { |
| 93 | t.Errorf("expected expired session to be anonymous, got %+v", loaded) |
| 94 | } |
| 95 | } |
| 96 | |
| 97 | func TestCookieStore_KeySizeEnforced(t *testing.T) { |
| 98 | t.Parallel() |
| 99 | if _, err := NewCookieStore(CookieStoreConfig{Key: []byte("too-short")}); err == nil { |
| 100 | t.Errorf("expected error for short key") |
| 101 | } |
| 102 | } |
| 103 | |
| 104 | func TestCookieStore_ClearDeletesCookie(t *testing.T) { |
| 105 | t.Parallel() |
| 106 | store := newStore(t) |
| 107 | rec := httptest.NewRecorder() |
| 108 | store.Clear(rec) |
| 109 | cookies := rec.Result().Cookies() |
| 110 | if len(cookies) != 1 { |
| 111 | t.Fatalf("expected one cookie, got %d", len(cookies)) |
| 112 | } |
| 113 | if cookies[0].MaxAge >= 0 { |
| 114 | t.Errorf("Clear cookie MaxAge: got %d, want negative", cookies[0].MaxAge) |
| 115 | } |
| 116 | header := rec.Header().Get("Set-Cookie") |
| 117 | if !strings.Contains(header, CookieName) { |
| 118 | t.Errorf("Set-Cookie missing %s: %q", CookieName, header) |
| 119 | } |
| 120 | } |
| 121 |