Go · 13456 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package search wires the S28 web search surface. The full results
4 // page lives at GET /search; the nav quick dropdown lives at GET
5 // /search/quick.
6 package search
7
8 import (
9 "context"
10 "errors"
11 "log/slog"
12 "net/http"
13 "net/url"
14 "time"
15
16 "github.com/go-chi/chi/v5"
17 "github.com/jackc/pgx/v5/pgxpool"
18
19 "github.com/tenseleyFlow/shithub/internal/auth/policy"
20 "github.com/tenseleyFlow/shithub/internal/ratelimit"
21 srch "github.com/tenseleyFlow/shithub/internal/search"
22 "github.com/tenseleyFlow/shithub/internal/web/middleware"
23 "github.com/tenseleyFlow/shithub/internal/web/render"
24 )
25
26 // Deps wires the handler set.
27 type Deps struct {
28 Logger *slog.Logger
29 Render *render.Renderer
30 Pool *pgxpool.Pool
31 // Limiter, when non-nil, gates /search per-(viewer or IP). Audit
32 // 2026-05-10 H4: search renders amplify FTS cost 5×–6× per
33 // request, so without a limiter a single client can hammer the
34 // DB. Optional in tests; required in production wiring.
35 Limiter *ratelimit.Limiter
36 }
37
38 // Handlers is the registered handler set. Construct via New.
39 type Handlers struct {
40 d Deps
41 tabsCache *tabsCache // nil-safe — Mount constructs it
42 }
43
44 // New constructs the handler set, validating Deps.
45 func New(d Deps) (*Handlers, error) {
46 if d.Render == nil {
47 return nil, errors.New("search: nil Render")
48 }
49 if d.Pool == nil {
50 return nil, errors.New("search: nil Pool")
51 }
52 return &Handlers{d: d, tabsCache: newTabsCache()}, nil
53 }
54
55 // SearchRateLimitPolicy is the per-(viewer or IP) limit applied to
56 // /search and /search/quick. 60/min is generous for human use
57 // (typical browse rate is well under this) but cheap to defeat any
58 // query-rotation attack that bypasses the tab-count cache (audit
59 // 2026-05-10 H4+H5). Surfaced as a var so tests can tighten it.
60 var SearchRateLimitPolicy = ratelimit.Policy{
61 Scope: "search",
62 Max: 60,
63 Window: 1 * time.Minute,
64 }
65
66 // Mount registers /search and /search/quick. When d.Limiter is set,
67 // both routes go through the rate-limit middleware before reaching
68 // the handlers — protects the FTS path from query-rotation attacks
69 // that the tab-counts cache alone can't absorb.
70 func (h *Handlers) Mount(r chi.Router) {
71 if h.d.Limiter != nil {
72 r.Group(func(r chi.Router) {
73 r.Use(h.d.Limiter.Middleware(SearchRateLimitPolicy, searchRateLimitKey))
74 r.Get("/search", h.results)
75 r.Get("/search/quick", h.quick)
76 })
77 return
78 }
79 r.Get("/search", h.results)
80 r.Get("/search/quick", h.quick)
81 }
82
83 // searchRateLimitKey picks the per-request key. Authed users key
84 // on user_id (so an attacker can't bypass by hopping accounts they
85 // don't have); anonymous users key on the trusted client IP. We
86 // trust X-Forwarded-For only when middleware.RealIP has already
87 // vetted it, which it does at the global stack level.
88 func searchRateLimitKey(r *http.Request) string {
89 viewer := middleware.CurrentUserFromContext(r.Context())
90 if !viewer.IsAnonymous() {
91 return "u:" + intString(int(viewer.ID))
92 }
93 if ip, ok := ratelimit.ClientIP(r, true); ok {
94 return "ip:" + ip.String()
95 }
96 return ""
97 }
98
99 func (h *Handlers) deps() srch.Deps {
100 return srch.Deps{Pool: h.d.Pool, Logger: h.d.Logger}
101 }
102
103 func (h *Handlers) actor(r *http.Request) policy.Actor {
104 viewer := middleware.CurrentUserFromContext(r.Context())
105 if viewer.IsAnonymous() {
106 return policy.AnonymousActor()
107 }
108 return viewer.PolicyActor()
109 }
110
111 // results renders the full /search page with type tabs.
112 func (h *Handlers) results(w http.ResponseWriter, r *http.Request) {
113 rawQ := r.URL.Query().Get("q")
114 tab := normalizeSearchTab(r.URL.Query().Get("type"))
115 page := pageFromRequest(r)
116
117 parsed := srch.ParseQuery(rawQ)
118 actor := h.actor(r)
119 deps := h.deps()
120
121 data := map[string]any{
122 "Title": "Search",
123 "Query": rawQ,
124 "GlobalSearchQuery": rawQ,
125 "Tab": tab,
126 "Page": page,
127 "Parsed": parsed,
128 "PageSize": srch.PageSize,
129 "SearchProTip": searchProTip(tab),
130 }
131
132 if !parsed.HasContent() {
133 data["EmptyQuery"] = true
134 data["SearchTabs"] = h.searchTabs(r, actor, parsed, rawQ, tab)
135 _ = h.d.Render.RenderPage(w, r, "search/results", data)
136 return
137 }
138
139 offset := (page - 1) * srch.PageSize
140 switch tab {
141 case "repositories":
142 rows, total, err := srch.SearchRepos(r.Context(), deps, actor, parsed, srch.PageSize, offset)
143 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
144 h.d.Logger.ErrorContext(r.Context(), "search repos", "error", err)
145 }
146 data["Repos"] = rows
147 data["Total"] = total
148 data["HasNext"] = int64(page*srch.PageSize) < total
149 case "issues":
150 rows, total, err := srch.SearchIssues(r.Context(), deps, actor, parsed, "issue", srch.PageSize, offset)
151 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
152 h.d.Logger.ErrorContext(r.Context(), "search issues", "error", err)
153 }
154 data["Issues"] = rows
155 data["Total"] = total
156 data["HasNext"] = int64(page*srch.PageSize) < total
157 case "pullrequests":
158 rows, total, err := srch.SearchIssues(r.Context(), deps, actor, parsed, "pr", srch.PageSize, offset)
159 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
160 h.d.Logger.ErrorContext(r.Context(), "search pulls", "error", err)
161 }
162 data["Issues"] = rows
163 data["Total"] = total
164 data["HasNext"] = int64(page*srch.PageSize) < total
165 case "users":
166 rows, total, err := srch.SearchUsers(r.Context(), deps, parsed, srch.PageSize, offset)
167 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
168 h.d.Logger.ErrorContext(r.Context(), "search users", "error", err)
169 }
170 data["Users"] = rows
171 data["Total"] = total
172 data["HasNext"] = int64(page*srch.PageSize) < total
173 case "code":
174 rows, total, err := srch.SearchCode(r.Context(), deps, actor, parsed, srch.PageSize, offset)
175 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
176 h.d.Logger.ErrorContext(r.Context(), "search code", "error", err)
177 }
178 data["Code"] = rows
179 data["Total"] = total
180 data["HasNext"] = int64(page*srch.PageSize) < total
181 default:
182 // Unknown tab → render the page with the empty-state shape
183 // rather than 400 (a typo in the URL shouldn't be a hard
184 // error).
185 data["EmptyQuery"] = true
186 }
187 data["HasPrev"] = page > 1
188 data["SearchTabs"] = h.searchTabs(r, actor, parsed, rawQ, tab)
189 data["ResultHeading"] = searchResultHeading(tab, data["Total"])
190 if page > 1 {
191 data["PrevHref"] = searchHref(rawQ, tab, page-1)
192 }
193 if next, ok := data["HasNext"].(bool); ok && next {
194 data["NextHref"] = searchHref(rawQ, tab, page+1)
195 }
196
197 if err := h.d.Render.RenderPage(w, r, "search/results", data); err != nil {
198 h.d.Logger.ErrorContext(r.Context(), "search render", "error", err)
199 }
200 }
201
202 func normalizeSearchTab(tab string) string {
203 switch tab {
204 case "", "repos", "repositories":
205 return "repositories"
206 case "code":
207 return "code"
208 case "issues":
209 return "issues"
210 case "pulls", "pullrequests":
211 return "pullrequests"
212 case "users":
213 return "users"
214 default:
215 return "repositories"
216 }
217 }
218
219 type searchTab struct {
220 Key string
221 Label string
222 Icon string
223 Count int64
224 Href string
225 Selected bool
226 }
227
228 func (h *Handlers) searchTabs(r *http.Request, actor policy.Actor, parsed srch.ParsedQuery, rawQ, active string) []searchTab {
229 tabs := []searchTab{
230 {Key: "code", Label: "Code", Icon: "code"},
231 {Key: "repositories", Label: "Repositories", Icon: "repo"},
232 {Key: "issues", Label: "Issues", Icon: "issue-opened"},
233 {Key: "pullrequests", Label: "Pull requests", Icon: "git-pull-request"},
234 {Key: "users", Label: "Users", Icon: "people"},
235 }
236 for i := range tabs {
237 tabs[i].Selected = tabs[i].Key == active
238 tabs[i].Href = searchHref(rawQ, tabs[i].Key, 1)
239 }
240 if !parsed.HasContent() {
241 return tabs
242 }
243
244 // Counts are cached per-(query, viewer) for tabsCacheTTL. The
245 // active-tab's actual result rows are NOT cached here — only the
246 // 5 count-only badge calls that pre-fix were the dominant cost
247 // (audit 2026-05-10 H5). Single-flighted via lru.Group so a
248 // thundering-herd on the same key doesn't spawn N waves.
249 key := tabsCacheKey{q: canonicalizeQuery(parsed), userID: actorUserID(actor)}
250 cached, err := h.tabsCache.g.Do(r.Context(), key, func(ctx context.Context) ([]searchTab, error) {
251 return h.computeTabCounts(ctx, actor, parsed), nil
252 })
253 if err != nil {
254 // Group.Do never caches errors and our fetch returns nil; this
255 // path is unreachable today but kept for defensiveness.
256 h.d.Logger.ErrorContext(r.Context(), "search tabs cache", "error", err)
257 cached = h.computeTabCounts(r.Context(), actor, parsed)
258 }
259 // Merge cached counts into the freshly-built (Selected/Href-aware)
260 // tabs slice. The cached value carries Counts and the same Key
261 // ordering; everything else is per-request and not cached.
262 for i := range tabs {
263 for j := range cached {
264 if cached[j].Key == tabs[i].Key {
265 tabs[i].Count = cached[j].Count
266 break
267 }
268 }
269 }
270 return tabs
271 }
272
273 // computeTabCounts is the cache miss path: 5 FTS count-only queries.
274 // Returned slice carries (Key, Count) only — Selected/Href/Label/
275 // Icon are per-request and applied by the caller.
276 func (h *Handlers) computeTabCounts(ctx context.Context, actor policy.Actor, parsed srch.ParsedQuery) []searchTab {
277 deps := h.deps()
278 out := []searchTab{
279 {Key: "code"},
280 {Key: "repositories"},
281 {Key: "issues"},
282 {Key: "pullrequests"},
283 {Key: "users"},
284 }
285 for i := range out {
286 var total int64
287 var err error
288 switch out[i].Key {
289 case "repositories":
290 _, total, err = srch.SearchRepos(ctx, deps, actor, parsed, 0, 0)
291 case "code":
292 _, total, err = srch.SearchCode(ctx, deps, actor, parsed, 0, 0)
293 case "issues":
294 _, total, err = srch.SearchIssues(ctx, deps, actor, parsed, "issue", 0, 0)
295 case "pullrequests":
296 _, total, err = srch.SearchIssues(ctx, deps, actor, parsed, "pr", 0, 0)
297 case "users":
298 _, total, err = srch.SearchUsers(ctx, deps, parsed, 0, 0)
299 }
300 if err != nil && !errors.Is(err, srch.ErrEmptyQuery) {
301 h.d.Logger.ErrorContext(ctx, "search tab count", "tab", out[i].Key, "error", err)
302 continue
303 }
304 out[i].Count = total
305 }
306 return out
307 }
308
309 // actorUserID returns 0 for anonymous, the user_id otherwise. Used
310 // as the (anon vs each-authed-user) discriminant in the tabs cache
311 // key — anonymous viewers all see the same public-only result set
312 // so they share a slot; authed viewers see private results based
313 // on their collab roles, so each gets their own.
314 func actorUserID(a policy.Actor) int64 {
315 if a.IsAnonymous {
316 return 0
317 }
318 return a.UserID
319 }
320
321 func searchHref(q, tab string, page int) string {
322 v := url.Values{}
323 v.Set("q", q)
324 v.Set("type", tab)
325 if page > 1 {
326 v.Set("p", intString(page))
327 }
328 return "/search?" + v.Encode()
329 }
330
331 func searchResultHeading(tab string, total any) string {
332 count, _ := total.(int64)
333 switch tab {
334 case "code":
335 return plural(count, "code result", "code results")
336 case "issues":
337 return plural(count, "issue result", "issue results")
338 case "pullrequests":
339 return plural(count, "pull request result", "pull request results")
340 case "users":
341 return plural(count, "user result", "user results")
342 default:
343 return plural(count, "repository result", "repository results")
344 }
345 }
346
347 func plural(count int64, one, many string) string {
348 if count == 1 {
349 return "1 " + one
350 }
351 return int64String(count) + " " + many
352 }
353
354 func int64String(n int64) string {
355 if n == 0 {
356 return "0"
357 }
358 var buf [20]byte
359 i := len(buf)
360 for n > 0 {
361 i--
362 buf[i] = byte('0' + n%10)
363 n /= 10
364 }
365 return string(buf[i:])
366 }
367
368 func searchProTip(tab string) string {
369 switch tab {
370 case "issues", "pullrequests":
371 return "Restrict your search to the title by using the in:title qualifier."
372 case "code":
373 return "Use repo:owner/name to limit code search to a single repository."
374 case "users":
375 return "Search by username or display name to find people faster."
376 default:
377 return "Press / to activate the search input again and adjust your query."
378 }
379 }
380
381 // quick is the nav dropdown endpoint. Returns one fragment with
382 // the top N results across the implemented quick-search types.
383 func (h *Handlers) quick(w http.ResponseWriter, r *http.Request) {
384 rawQ := r.URL.Query().Get("q")
385 parsed := srch.ParseQuery(rawQ)
386 if !parsed.HasContent() {
387 w.WriteHeader(http.StatusNoContent)
388 return
389 }
390 actor := h.actor(r)
391 deps := h.deps()
392
393 repos, _, _ := srch.SearchRepos(r.Context(), deps, actor, parsed, srch.QuickResultsLimit, 0)
394 issues, _, _ := srch.SearchIssues(r.Context(), deps, actor, parsed, "", srch.QuickResultsLimit, 0)
395 users, _, _ := srch.SearchUsers(r.Context(), deps, parsed, srch.QuickResultsLimit, 0)
396
397 data := map[string]any{
398 "Query": rawQ,
399 "SearchHref": searchHref(rawQ, "repositories", 1),
400 "Repos": repos,
401 "Issues": issues,
402 "Users": users,
403 }
404 if err := h.d.Render.RenderFragment(w, "search/quick_dropdown", data); err != nil {
405 h.d.Logger.ErrorContext(r.Context(), "quick render", "error", err)
406 }
407 }
408
409 // pageFromRequest pulls ?page=N, defaulting to 1 on missing/invalid.
410 func pageFromRequest(r *http.Request) int {
411 p := r.URL.Query().Get("p")
412 if p == "" {
413 p = r.URL.Query().Get("page")
414 }
415 if p == "" {
416 return 1
417 }
418 n := 0
419 for _, c := range p {
420 if c < '0' || c > '9' {
421 return 1
422 }
423 n = n*10 + int(c-'0')
424 if n > 10000 {
425 return 1
426 }
427 }
428 if n < 1 {
429 return 1
430 }
431 return n
432 }
433
434 func intString(n int) string {
435 if n == 0 {
436 return "0"
437 }
438 var buf [20]byte
439 i := len(buf)
440 for n > 0 {
441 i--
442 buf[i] = byte('0' + n%10)
443 n /= 10
444 }
445 return string(buf[i:])
446 }
447