Go · 8576 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package main
4
5 import (
6 "context"
7 "fmt"
8 "log/slog"
9 "net/netip"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "strconv"
14 "strings"
15 "syscall"
16 "time"
17
18 "github.com/spf13/cobra"
19
20 "github.com/tenseleyFlow/shithub/internal/git/protocol"
21 "github.com/tenseleyFlow/shithub/internal/infra/config"
22 "github.com/tenseleyFlow/shithub/internal/infra/db"
23 "github.com/tenseleyFlow/shithub/internal/infra/storage"
24 usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc"
25 )
26
27 // sshAuthkeysCmd implements sshd's AuthorizedKeysCommand contract:
28 //
29 // - On a known fingerprint, write a single authorized_keys line on stdout
30 // with a forced command and restrictive options.
31 // - On an unknown fingerprint OR any error, write nothing and exit 0.
32 // sshd uses STDOUT as the auth answer; non-zero exit is a config error,
33 // not a deny. Failing closed is the right model: better to deny a
34 // legitimate connection than accidentally authorize the wrong user.
35 //
36 // Latency is critical — every SSH connection waits on this. The pool is
37 // sized small (max 4 conns) to bound startup cost and tail-latency.
38 var sshAuthkeysCmd = &cobra.Command{
39 Use: "ssh-authkeys <fingerprint>",
40 Short: "AuthorizedKeysCommand handler for sshd",
41 Args: cobra.ExactArgs(1),
42 Hidden: true, // not for direct human use
43 RunE: func(cmd *cobra.Command, args []string) error {
44 // Fail-closed wrapper: anything below that returns an error or
45 // panics writes nothing to stdout. The exit code stays 0.
46 defer func() {
47 _ = recover()
48 }()
49 fp := strings.TrimSpace(args[0])
50 if !isWellFormedFingerprint(fp) {
51 return nil
52 }
53
54 cfg, err := config.Load(nil)
55 if err != nil || cfg.DB.URL == "" {
56 return nil
57 }
58
59 ctx, cancel := context.WithTimeout(cmd.Context(), 1500*time.Millisecond)
60 defer cancel()
61
62 pool, err := db.Open(ctx, db.Config{
63 URL: cfg.DB.URL, MaxConns: 4, MinConns: 0,
64 ConnectTimeout: 750 * time.Millisecond,
65 })
66 if err != nil {
67 return nil
68 }
69 defer pool.Close()
70
71 q := usersdb.New()
72 row, err := q.GetUserSSHKeyByFingerprint(ctx, pool, fp)
73 if err != nil {
74 // pgx.ErrNoRows or any other error → silently empty.
75 return nil
76 }
77
78 _, _ = fmt.Fprintln(cmd.OutOrStdout(), authorizedKeysLine(row))
79
80 // Best-effort last-used update. 500ms cap; any error is dropped.
81 updateCtx, updateCancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
82 defer updateCancel()
83 _ = q.TouchSSHKeyLastUsed(updateCtx, pool, usersdb.TouchSSHKeyLastUsedParams{
84 ID: row.ID,
85 LastUsedIp: clientAddrFromEnv(),
86 })
87 return nil
88 },
89 }
90
91 // sshShellCmd is the forced-command target sshd invokes after the
92 // AuthorizedKeysCommand handshake binds the connection to a user.
93 //
94 // Flow on a successful clone/push:
95 //
96 // sshd ──► shithubd ssh-shell <user_id>
97 // ├─ ParseSSHCommand(SSH_ORIGINAL_COMMAND)
98 // ├─ Resolve user + repo against the DB
99 // ├─ Inline owner-only authz (S15 will refactor)
100 // ├─ Build SHITHUB_* env (so post-receive hooks identify the actor)
101 // ├─ Close the DB pool (syscall.Exec preserves all open FDs)
102 // └─ syscall.Exec git-{upload,receive}-pack <bare-repo>
103 //
104 // On any error: write a friendly line to stderr (the user sees it in
105 // their git client), log structured, exit non-zero. defer does NOT
106 // fire on syscall.Exec — every cleanup happens BEFORE the exec call.
107 var sshShellCmd = &cobra.Command{
108 Use: "ssh-shell <user_id>",
109 Short: "Forced-command target invoked by sshd via AuthorizedKeysCommand",
110 Args: cobra.ExactArgs(1),
111 Hidden: true,
112 RunE: func(cmd *cobra.Command, args []string) error {
113 userID, err := strconv.ParseInt(args[0], 10, 64)
114 if err != nil {
115 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: invalid user")
116 return fmt.Errorf("ssh-shell: bad user_id %q: %w", args[0], err)
117 }
118 original := os.Getenv("SSH_ORIGINAL_COMMAND")
119 remoteIP := protocol.ParseRemoteIP(os.Getenv("SSH_CONNECTION"))
120 logger := slog.New(slog.NewTextHandler(cmd.ErrOrStderr(), &slog.HandlerOptions{Level: slog.LevelInfo}))
121
122 cfg, err := config.Load(nil)
123 if err != nil || cfg.DB.URL == "" || cfg.Storage.ReposRoot == "" {
124 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: server misconfigured")
125 return fmt.Errorf("ssh-shell: cfg: %w", err)
126 }
127 root, err := filepath.Abs(cfg.Storage.ReposRoot)
128 if err != nil {
129 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: server misconfigured")
130 return fmt.Errorf("ssh-shell: repos_root: %w", err)
131 }
132 rfs, err := storage.NewRepoFS(root)
133 if err != nil {
134 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: server misconfigured")
135 return fmt.Errorf("ssh-shell: NewRepoFS: %w", err)
136 }
137
138 ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
139 defer cancel()
140 pool, err := db.Open(ctx, db.Config{
141 URL: cfg.DB.URL, MaxConns: 2, MinConns: 0,
142 ConnectTimeout: 1500 * time.Millisecond,
143 })
144 if err != nil {
145 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: temporary failure (try again)")
146 return fmt.Errorf("ssh-shell: db open: %w", err)
147 }
148
149 res, parsed, dispatchErr := protocol.PrepareDispatch(ctx, protocol.SSHDispatchDeps{
150 Pool: pool, RepoFS: rfs,
151 }, protocol.SSHDispatchInput{
152 OriginalCommand: original,
153 UserID: userID,
154 RemoteIP: remoteIP,
155 })
156 if dispatchErr != nil {
157 pool.Close()
158 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), protocol.FriendlyMessageFor(dispatchErr, ""))
159 logger.WarnContext(ctx, "ssh-shell: denied",
160 "user_id", userID,
161 "original", original,
162 "remote_ip", remoteIP,
163 "error", dispatchErr,
164 )
165 return dispatchErr
166 }
167 logger.InfoContext(ctx, "ssh-shell: dispatch",
168 "user_id", userID,
169 "op", string(parsed.Service),
170 "owner", parsed.Owner,
171 "repo", parsed.Repo,
172 "remote_ip", remoteIP,
173 )
174
175 // CRITICAL: close DB pool before syscall.Exec. defer doesn't
176 // fire on exec, and the pgx pool's connections would otherwise
177 // leak into the new process's FD table.
178 pool.Close()
179
180 bin, err := exec.LookPath(res.Argv0)
181 if err != nil {
182 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: server misconfigured")
183 return fmt.Errorf("ssh-shell: lookup %s: %w", res.Argv0, err)
184 }
185 if err := sysExec(bin, res.Argv0Args, res.Env); err != nil {
186 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "shithub: internal error")
187 return fmt.Errorf("ssh-shell: exec %s: %w", bin, err)
188 }
189 // Unreachable on success — syscall.Exec replaces this process.
190 return nil
191 },
192 }
193
194 // sysExec is split out so tests can stub it. bin is exec.LookPath of a
195 // fixed service name (git-{upload,receive}-pack); argv[1] is the
196 // sanitized bare-repo path from storage.RepoFS.
197 //
198 //nolint:gosec // G204: inputs are constrained as documented above.
199 var sysExec = syscall.Exec
200
201 // authorizedKeysLine builds the single line sshd consumes. The forced
202 // command runs `shithubd ssh-shell <user_id>`; the option set strips
203 // every interactive affordance.
204 func authorizedKeysLine(row usersdb.UserSshKey) string {
205 binary := os.Args[0]
206 // Quote-escape only the binary path; user_id is a digit string so it
207 // can never contain shell metacharacters.
208 cmd := fmt.Sprintf(`%s ssh-shell %d`, binary, row.UserID)
209 options := strings.Join([]string{
210 fmt.Sprintf(`command="%s"`, cmd),
211 "no-port-forwarding",
212 "no-X11-forwarding",
213 "no-agent-forwarding",
214 "no-pty",
215 }, ",")
216 return options + " " + row.PublicKey
217 }
218
219 // clientAddrFromEnv extracts the connecting client's address from
220 // $SSH_CONNECTION (sshd sets it to "<client> <cport> <server> <sport>").
221 // Returns nil when unavailable, which sqlc encodes as a SQL NULL.
222 func clientAddrFromEnv() *netip.Addr {
223 conn := os.Getenv("SSH_CONNECTION")
224 if conn == "" {
225 return nil
226 }
227 parts := strings.Fields(conn)
228 if len(parts) < 1 {
229 return nil
230 }
231 addr, err := netip.ParseAddr(parts[0])
232 if err != nil {
233 return nil
234 }
235 return &addr
236 }
237
238 // isWellFormedFingerprint accepts only the canonical SHA256:<b64> shape
239 // our codebase emits. Defense against an attacker passing crafted strings
240 // to influence the SQL plan.
241 func isWellFormedFingerprint(s string) bool {
242 if !strings.HasPrefix(s, "SHA256:") {
243 return false
244 }
245 rest := s[len("SHA256:"):]
246 if len(rest) < 30 || len(rest) > 80 {
247 return false
248 }
249 for _, r := range rest {
250 switch {
251 case r >= 'A' && r <= 'Z',
252 r >= 'a' && r <= 'z',
253 r >= '0' && r <= '9',
254 r == '+', r == '/', r == '=':
255 default:
256 return false
257 }
258 }
259 return true
260 }
261
262 func init() {
263 rootCmd.AddCommand(sshAuthkeysCmd)
264 rootCmd.AddCommand(sshShellCmd)
265 }
266