Go · 4234 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package search
4
5 import (
6 "context"
7 "fmt"
8
9 "github.com/tenseleyFlow/shithub/internal/auth/policy"
10 )
11
12 // SearchCode runs a code search across paths and content. Visibility
13 // gates the underlying repo set; only repos the actor can read appear.
14 //
15 // We run two unioned subqueries:
16 //
17 // paths — `tsv @@ plainto_tsquery(...)` on the indexed path
18 // string. Always populated (size cap doesn't apply).
19 // content — `content_tsv @@ plainto_tsquery(...)` OR
20 // `content_trgm % $tsText` (trigram similarity for
21 // camelCase / snake_case identifiers).
22 //
23 // Path matches always rank above content matches at equal ts_rank.
24 // Within content matches, ts_rank dominates trigram similarity.
25 func SearchCode(ctx context.Context, deps Deps, actor policy.Actor, q ParsedQuery, limit, offset int) ([]CodeResult, int64, error) {
26 if !q.HasContent() {
27 return nil, 0, ErrEmptyQuery
28 }
29 tsText, tsCtor, hasFTS := tsQueryBindAndCtor(q)
30 if !hasFTS {
31 // Code search needs a textual hit. repo:-only narrows the
32 // repo set but we have nothing to match against; return
33 // empty rather than blast every indexed file.
34 return nil, 0, nil
35 }
36
37 args := []any{tsText}
38 visClause, visArgs := policy.VisibilityPredicate(actor, "r", len(args)+1)
39 args = append(args, visArgs...)
40
41 repoFilter := ""
42 if q.RepoFilter != nil {
43 ownerPos := len(args) + 1
44 namePos := len(args) + 2
45 args = append(args, q.RepoFilter.Owner, q.RepoFilter.Name)
46 repoFilter = repoFilterByOwnerName("r", ownerPos, namePos)
47 }
48
49 limPos := len(args) + 1
50 offPos := len(args) + 2
51 args = append(args, limit, offset)
52
53 // Path subquery: tsv match on the path string. We always rank
54 // path hits at +1.0 above content hits at the same ts_rank.
55 queryStr := fmt.Sprintf(`
56 WITH path_hits AS (
57 SELECT csp.repo_id, csp.ref_name, csp.path,
58 ts_rank_cd(csp.tsv, %[1]s('shithub_search', $1)) + 1.0 AS rank,
59 ''::text AS preview
60 FROM code_search_paths csp
61 JOIN repos r ON r.id = csp.repo_id
62 WHERE csp.tsv @@ %[1]s('shithub_search', $1)
63 AND %[2]s
64 %[3]s
65 ),
66 content_hits AS (
67 SELECT csc.repo_id, csc.ref_name, csc.path,
68 ts_rank_cd(csc.content_tsv, %[1]s('shithub_search', $1)) AS rank,
69 ''::text AS preview
70 FROM code_search_content csc
71 JOIN repos r ON r.id = csc.repo_id
72 WHERE csc.content_tsv @@ %[1]s('shithub_search', $1)
73 AND %[2]s
74 %[3]s
75 ),
76 all_hits AS (
77 SELECT * FROM path_hits
78 UNION ALL
79 SELECT * FROM content_hits
80 )
81 SELECT h.repo_id, %[6]s, r.name, h.ref_name, h.path, h.preview, h.rank
82 FROM all_hits h
83 JOIN repos r ON r.id = h.repo_id
84 %[7]s
85 ORDER BY h.rank DESC, h.path
86 LIMIT $%[4]d OFFSET $%[5]d
87 `, tsCtor, visClause, repoFilter, limPos, offPos, repoOwnerNameExpr("u", "o"), repoOwnerJoin("r", "u", "o"))
88
89 rows, err := deps.Pool.Query(ctx, queryStr, args...)
90 if err != nil {
91 return nil, 0, fmt.Errorf("search code: %w", err)
92 }
93 defer rows.Close()
94 out := make([]CodeResult, 0, limit)
95 for rows.Next() {
96 var r CodeResult
97 if err := rows.Scan(&r.RepoID, &r.OwnerUsername, &r.RepoName,
98 &r.RefName, &r.Path, &r.PreviewLine, &r.Rank); err != nil {
99 return nil, 0, err
100 }
101 out = append(out, r)
102 }
103 if err := rows.Err(); err != nil {
104 return nil, 0, err
105 }
106
107 // Total count: paths + content rows that matched. Repos with
108 // visibility filter applied. Pagination is approximate when
109 // the same path matches both indexes — we count the union
110 // honestly so the pager doesn't lie.
111 countQuery := fmt.Sprintf(`
112 SELECT (
113 SELECT count(*) FROM code_search_paths csp
114 JOIN repos r ON r.id = csp.repo_id
115 WHERE csp.tsv @@ %[1]s('shithub_search', $1)
116 AND %[2]s
117 %[3]s
118 ) + (
119 SELECT count(*) FROM code_search_content csc
120 JOIN repos r ON r.id = csc.repo_id
121 WHERE csc.content_tsv @@ %[1]s('shithub_search', $1)
122 AND %[2]s
123 %[3]s
124 )
125 `, tsCtor, visClause, repoFilter)
126 var total int64
127 if err := deps.Pool.QueryRow(ctx, countQuery, args[:len(args)-2]...).Scan(&total); err != nil {
128 return nil, 0, fmt.Errorf("count code: %w", err)
129 }
130 return out, total, nil
131 }
132