Go · 4371 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package search
4
5 import (
6 "context"
7 "fmt"
8
9 "github.com/tenseleyFlow/shithub/internal/auth/policy"
10 )
11
12 // SearchCode runs a code search across paths and content. Visibility
13 // gates the underlying repo set; only repos the actor can read appear.
14 //
15 // We run two unioned subqueries:
16 //
17 // paths — `tsv @@ plainto_tsquery(...)` on the indexed path
18 // string. Always populated (size cap doesn't apply).
19 // content — `content_tsv @@ plainto_tsquery(...)` OR
20 // `content_trgm % $tsText` (trigram similarity for
21 // camelCase / snake_case identifiers).
22 //
23 // Path matches always rank above content matches at equal ts_rank.
24 // Within content matches, ts_rank dominates trigram similarity.
25 func SearchCode(ctx context.Context, deps Deps, actor policy.Actor, q ParsedQuery, limit, offset int) ([]CodeResult, int64, error) {
26 if !q.HasContent() {
27 return nil, 0, ErrEmptyQuery
28 }
29 tsText, tsCtor, hasFTS := tsQueryBindAndCtor(q)
30 if !hasFTS {
31 // Code search needs a textual hit. repo:-only narrows the
32 // repo set but we have nothing to match against; return
33 // empty rather than blast every indexed file.
34 return nil, 0, nil
35 }
36
37 args := []any{tsText}
38 visClause, visArgs := policy.VisibilityPredicate(actor, "r", len(args)+1)
39 args = append(args, visArgs...)
40
41 repoFilter := ""
42 if q.RepoFilter != nil {
43 ownerPos := len(args) + 1
44 namePos := len(args) + 2
45 args = append(args, q.RepoFilter.Owner, q.RepoFilter.Name)
46 repoFilter = fmt.Sprintf(
47 " AND r.id = (SELECT r2.id FROM repos r2 JOIN users u2 ON u2.id = r2.owner_user_id "+
48 "WHERE u2.username = $%d AND r2.name = $%d AND r2.deleted_at IS NULL)",
49 ownerPos, namePos,
50 )
51 }
52
53 limPos := len(args) + 1
54 offPos := len(args) + 2
55 args = append(args, limit, offset)
56
57 // Path subquery: tsv match on the path string. We always rank
58 // path hits at +1.0 above content hits at the same ts_rank.
59 queryStr := fmt.Sprintf(`
60 WITH path_hits AS (
61 SELECT csp.repo_id, csp.ref_name, csp.path,
62 ts_rank_cd(csp.tsv, %[1]s('shithub_search', $1)) + 1.0 AS rank,
63 ''::text AS preview
64 FROM code_search_paths csp
65 JOIN repos r ON r.id = csp.repo_id
66 WHERE csp.tsv @@ %[1]s('shithub_search', $1)
67 AND %[2]s
68 %[3]s
69 ),
70 content_hits AS (
71 SELECT csc.repo_id, csc.ref_name, csc.path,
72 ts_rank_cd(csc.content_tsv, %[1]s('shithub_search', $1)) AS rank,
73 ''::text AS preview
74 FROM code_search_content csc
75 JOIN repos r ON r.id = csc.repo_id
76 WHERE csc.content_tsv @@ %[1]s('shithub_search', $1)
77 AND %[2]s
78 %[3]s
79 ),
80 all_hits AS (
81 SELECT * FROM path_hits
82 UNION ALL
83 SELECT * FROM content_hits
84 )
85 SELECT h.repo_id, u.username, r.name, h.ref_name, h.path, h.preview, h.rank
86 FROM all_hits h
87 JOIN repos r ON r.id = h.repo_id
88 JOIN users u ON u.id = r.owner_user_id
89 ORDER BY h.rank DESC, h.path
90 LIMIT $%[4]d OFFSET $%[5]d
91 `, tsCtor, visClause, repoFilter, limPos, offPos)
92
93 rows, err := deps.Pool.Query(ctx, queryStr, args...)
94 if err != nil {
95 return nil, 0, fmt.Errorf("search code: %w", err)
96 }
97 defer rows.Close()
98 out := make([]CodeResult, 0, limit)
99 for rows.Next() {
100 var r CodeResult
101 if err := rows.Scan(&r.RepoID, &r.OwnerUsername, &r.RepoName,
102 &r.RefName, &r.Path, &r.PreviewLine, &r.Rank); err != nil {
103 return nil, 0, err
104 }
105 out = append(out, r)
106 }
107 if err := rows.Err(); err != nil {
108 return nil, 0, err
109 }
110
111 // Total count: paths + content rows that matched. Repos with
112 // visibility filter applied. Pagination is approximate when
113 // the same path matches both indexes — we count the union
114 // honestly so the pager doesn't lie.
115 countQuery := fmt.Sprintf(`
116 SELECT (
117 SELECT count(*) FROM code_search_paths csp
118 JOIN repos r ON r.id = csp.repo_id
119 WHERE csp.tsv @@ %[1]s('shithub_search', $1)
120 AND %[2]s
121 %[3]s
122 ) + (
123 SELECT count(*) FROM code_search_content csc
124 JOIN repos r ON r.id = csc.repo_id
125 WHERE csc.content_tsv @@ %[1]s('shithub_search', $1)
126 AND %[2]s
127 %[3]s
128 )
129 `, tsCtor, visClause, repoFilter)
130 var total int64
131 if err := deps.Pool.QueryRow(ctx, countQuery, args[:len(args)-2]...).Scan(&total); err != nil {
132 return nil, 0, fmt.Errorf("count code: %w", err)
133 }
134 return out, total, nil
135 }
136