| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package search |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "fmt" |
| 8 | |
| 9 | "github.com/tenseleyFlow/shithub/internal/auth/policy" |
| 10 | ) |
| 11 | |
| 12 | // SearchCode runs a code search across paths and content. Visibility |
| 13 | // gates the underlying repo set; only repos the actor can read appear. |
| 14 | // |
| 15 | // We run two unioned subqueries: |
| 16 | // |
| 17 | // paths — `tsv @@ plainto_tsquery(...)` on the indexed path |
| 18 | // string. Always populated (size cap doesn't apply). |
| 19 | // content — `content_tsv @@ plainto_tsquery(...)` OR |
| 20 | // `content_trgm % $tsText` (trigram similarity for |
| 21 | // camelCase / snake_case identifiers). |
| 22 | // |
| 23 | // Path matches always rank above content matches at equal ts_rank. |
| 24 | // Within content matches, ts_rank dominates trigram similarity. |
| 25 | func SearchCode(ctx context.Context, deps Deps, actor policy.Actor, q ParsedQuery, limit, offset int) ([]CodeResult, int64, error) { |
| 26 | if !q.HasContent() { |
| 27 | return nil, 0, ErrEmptyQuery |
| 28 | } |
| 29 | tsText, tsCtor, hasFTS := tsQueryBindAndCtor(q) |
| 30 | if !hasFTS { |
| 31 | // Code search needs a textual hit. repo:-only narrows the |
| 32 | // repo set but we have nothing to match against; return |
| 33 | // empty rather than blast every indexed file. |
| 34 | return nil, 0, nil |
| 35 | } |
| 36 | |
| 37 | args := []any{tsText} |
| 38 | visClause, visArgs := policy.VisibilityPredicate(actor, "r", len(args)+1) |
| 39 | args = append(args, visArgs...) |
| 40 | |
| 41 | repoFilter := "" |
| 42 | if q.RepoFilter != nil { |
| 43 | ownerPos := len(args) + 1 |
| 44 | namePos := len(args) + 2 |
| 45 | args = append(args, q.RepoFilter.Owner, q.RepoFilter.Name) |
| 46 | repoFilter = fmt.Sprintf( |
| 47 | " AND r.id = (SELECT r2.id FROM repos r2 JOIN users u2 ON u2.id = r2.owner_user_id "+ |
| 48 | "WHERE u2.username = $%d AND r2.name = $%d AND r2.deleted_at IS NULL)", |
| 49 | ownerPos, namePos, |
| 50 | ) |
| 51 | } |
| 52 | |
| 53 | limPos := len(args) + 1 |
| 54 | offPos := len(args) + 2 |
| 55 | args = append(args, limit, offset) |
| 56 | |
| 57 | // Path subquery: tsv match on the path string. We always rank |
| 58 | // path hits at +1.0 above content hits at the same ts_rank. |
| 59 | queryStr := fmt.Sprintf(` |
| 60 | WITH path_hits AS ( |
| 61 | SELECT csp.repo_id, csp.ref_name, csp.path, |
| 62 | ts_rank_cd(csp.tsv, %[1]s('shithub_search', $1)) + 1.0 AS rank, |
| 63 | ''::text AS preview |
| 64 | FROM code_search_paths csp |
| 65 | JOIN repos r ON r.id = csp.repo_id |
| 66 | WHERE csp.tsv @@ %[1]s('shithub_search', $1) |
| 67 | AND %[2]s |
| 68 | %[3]s |
| 69 | ), |
| 70 | content_hits AS ( |
| 71 | SELECT csc.repo_id, csc.ref_name, csc.path, |
| 72 | ts_rank_cd(csc.content_tsv, %[1]s('shithub_search', $1)) AS rank, |
| 73 | ''::text AS preview |
| 74 | FROM code_search_content csc |
| 75 | JOIN repos r ON r.id = csc.repo_id |
| 76 | WHERE csc.content_tsv @@ %[1]s('shithub_search', $1) |
| 77 | AND %[2]s |
| 78 | %[3]s |
| 79 | ), |
| 80 | all_hits AS ( |
| 81 | SELECT * FROM path_hits |
| 82 | UNION ALL |
| 83 | SELECT * FROM content_hits |
| 84 | ) |
| 85 | SELECT h.repo_id, u.username, r.name, h.ref_name, h.path, h.preview, h.rank |
| 86 | FROM all_hits h |
| 87 | JOIN repos r ON r.id = h.repo_id |
| 88 | JOIN users u ON u.id = r.owner_user_id |
| 89 | ORDER BY h.rank DESC, h.path |
| 90 | LIMIT $%[4]d OFFSET $%[5]d |
| 91 | `, tsCtor, visClause, repoFilter, limPos, offPos) |
| 92 | |
| 93 | rows, err := deps.Pool.Query(ctx, queryStr, args...) |
| 94 | if err != nil { |
| 95 | return nil, 0, fmt.Errorf("search code: %w", err) |
| 96 | } |
| 97 | defer rows.Close() |
| 98 | out := make([]CodeResult, 0, limit) |
| 99 | for rows.Next() { |
| 100 | var r CodeResult |
| 101 | if err := rows.Scan(&r.RepoID, &r.OwnerUsername, &r.RepoName, |
| 102 | &r.RefName, &r.Path, &r.PreviewLine, &r.Rank); err != nil { |
| 103 | return nil, 0, err |
| 104 | } |
| 105 | out = append(out, r) |
| 106 | } |
| 107 | if err := rows.Err(); err != nil { |
| 108 | return nil, 0, err |
| 109 | } |
| 110 | |
| 111 | // Total count: paths + content rows that matched. Repos with |
| 112 | // visibility filter applied. Pagination is approximate when |
| 113 | // the same path matches both indexes — we count the union |
| 114 | // honestly so the pager doesn't lie. |
| 115 | countQuery := fmt.Sprintf(` |
| 116 | SELECT ( |
| 117 | SELECT count(*) FROM code_search_paths csp |
| 118 | JOIN repos r ON r.id = csp.repo_id |
| 119 | WHERE csp.tsv @@ %[1]s('shithub_search', $1) |
| 120 | AND %[2]s |
| 121 | %[3]s |
| 122 | ) + ( |
| 123 | SELECT count(*) FROM code_search_content csc |
| 124 | JOIN repos r ON r.id = csc.repo_id |
| 125 | WHERE csc.content_tsv @@ %[1]s('shithub_search', $1) |
| 126 | AND %[2]s |
| 127 | %[3]s |
| 128 | ) |
| 129 | `, tsCtor, visClause, repoFilter) |
| 130 | var total int64 |
| 131 | if err := deps.Pool.QueryRow(ctx, countQuery, args[:len(args)-2]...).Scan(&total); err != nil { |
| 132 | return nil, 0, fmt.Errorf("count code: %w", err) |
| 133 | } |
| 134 | return out, total, nil |
| 135 | } |
| 136 |