| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package api |
| 4 | |
| 5 | import ( |
| 6 | "encoding/json" |
| 7 | "errors" |
| 8 | "net/http" |
| 9 | "net/url" |
| 10 | "strconv" |
| 11 | "time" |
| 12 | |
| 13 | "github.com/go-chi/chi/v5" |
| 14 | "github.com/jackc/pgx/v5" |
| 15 | "github.com/jackc/pgx/v5/pgconn" |
| 16 | |
| 17 | "github.com/tenseleyFlow/shithub/internal/auth/pat" |
| 18 | "github.com/tenseleyFlow/shithub/internal/auth/sshkey" |
| 19 | usersdb "github.com/tenseleyFlow/shithub/internal/users/sqlc" |
| 20 | "github.com/tenseleyFlow/shithub/internal/web/handlers/api/apipage" |
| 21 | "github.com/tenseleyFlow/shithub/internal/web/middleware" |
| 22 | ) |
| 23 | |
| 24 | // mountUserKeys registers the SSH-keys REST surface for the authenticated |
| 25 | // user. Shape mirrors GitHub's /user/keys (authentication keys only; |
| 26 | // signing keys land at /user/ssh_signing_keys in a future batch). |
| 27 | // |
| 28 | // GET /api/v1/user/keys list (paginated) |
| 29 | // POST /api/v1/user/keys add { title, key } |
| 30 | // GET /api/v1/user/keys/{id} get one |
| 31 | // DELETE /api/v1/user/keys/{id} remove |
| 32 | // |
| 33 | // Scopes: user:read for GETs, user:write for POST/DELETE. |
| 34 | func (h *Handlers) mountUserKeys(r chi.Router) { |
| 35 | r.Group(func(r chi.Router) { |
| 36 | r.Use(middleware.RequireScope(pat.ScopeUserRead)) |
| 37 | r.Get("/api/v1/user/keys", h.userKeysList) |
| 38 | r.Get("/api/v1/user/keys/{id}", h.userKeyGet) |
| 39 | }) |
| 40 | r.Group(func(r chi.Router) { |
| 41 | r.Use(middleware.RequireScope(pat.ScopeUserWrite)) |
| 42 | r.Post("/api/v1/user/keys", h.userKeyCreate) |
| 43 | r.Delete("/api/v1/user/keys/{id}", h.userKeyDelete) |
| 44 | }) |
| 45 | } |
| 46 | |
| 47 | type userKeyResponse struct { |
| 48 | ID int64 `json:"id"` |
| 49 | Title string `json:"title"` |
| 50 | Key string `json:"key"` |
| 51 | Fingerprint string `json:"fingerprint"` |
| 52 | KeyType string `json:"key_type"` |
| 53 | Verified bool `json:"verified"` |
| 54 | ReadOnly bool `json:"read_only"` |
| 55 | CreatedAt string `json:"created_at"` |
| 56 | } |
| 57 | |
| 58 | func presentUserKey(k usersdb.UserSshKey) userKeyResponse { |
| 59 | return userKeyResponse{ |
| 60 | ID: k.ID, |
| 61 | Title: k.Title, |
| 62 | Key: k.PublicKey, |
| 63 | Fingerprint: "SHA256:" + k.FingerprintSha256, |
| 64 | KeyType: k.KeyType, |
| 65 | // Every key shithub stores has been parsed and validated at |
| 66 | // upload time. Surface as verified=true so gh-shaped clients |
| 67 | // that key off this field continue to work. |
| 68 | Verified: true, |
| 69 | ReadOnly: false, |
| 70 | CreatedAt: k.CreatedAt.Time.UTC().Format(time.RFC3339), |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | func (h *Handlers) userKeysList(w http.ResponseWriter, r *http.Request) { |
| 75 | auth := middleware.PATAuthFromContext(r.Context()) |
| 76 | if auth.UserID == 0 { |
| 77 | writeAPIError(w, http.StatusUnauthorized, "unauthenticated") |
| 78 | return |
| 79 | } |
| 80 | page, perPage := apipage.ParseQuery(r, apipage.DefaultPerPage, apipage.MaxPerPage) |
| 81 | total, err := h.q.CountUserSSHKeysByKind(r.Context(), h.d.Pool, usersdb.CountUserSSHKeysByKindParams{ |
| 82 | UserID: auth.UserID, Kind: "authentication", |
| 83 | }) |
| 84 | if err != nil { |
| 85 | h.d.Logger.ErrorContext(r.Context(), "api: count user keys", "error", err) |
| 86 | writeAPIError(w, http.StatusInternalServerError, "list failed") |
| 87 | return |
| 88 | } |
| 89 | rows, err := h.q.ListUserSSHKeysByKind(r.Context(), h.d.Pool, usersdb.ListUserSSHKeysByKindParams{ |
| 90 | UserID: auth.UserID, |
| 91 | Kind: "authentication", |
| 92 | Limit: int32(perPage), |
| 93 | Offset: int32((page - 1) * perPage), |
| 94 | }) |
| 95 | if err != nil { |
| 96 | h.d.Logger.ErrorContext(r.Context(), "api: list user keys", "error", err) |
| 97 | writeAPIError(w, http.StatusInternalServerError, "list failed") |
| 98 | return |
| 99 | } |
| 100 | link := apipage.Page{ |
| 101 | Current: page, PerPage: perPage, Total: int(total), |
| 102 | }.LinkHeader(h.d.BaseURL, sanitizedURL(r)) |
| 103 | if link != "" { |
| 104 | w.Header().Set("Link", link) |
| 105 | } |
| 106 | out := make([]userKeyResponse, 0, len(rows)) |
| 107 | for _, k := range rows { |
| 108 | out = append(out, presentUserKey(k)) |
| 109 | } |
| 110 | writeJSON(w, http.StatusOK, out) |
| 111 | } |
| 112 | |
| 113 | func (h *Handlers) userKeyGet(w http.ResponseWriter, r *http.Request) { |
| 114 | auth := middleware.PATAuthFromContext(r.Context()) |
| 115 | if auth.UserID == 0 { |
| 116 | writeAPIError(w, http.StatusUnauthorized, "unauthenticated") |
| 117 | return |
| 118 | } |
| 119 | id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) |
| 120 | if err != nil { |
| 121 | writeAPIError(w, http.StatusNotFound, "key not found") |
| 122 | return |
| 123 | } |
| 124 | k, err := h.q.GetUserSSHKey(r.Context(), h.d.Pool, usersdb.GetUserSSHKeyParams{ |
| 125 | ID: id, UserID: auth.UserID, |
| 126 | }) |
| 127 | if err != nil { |
| 128 | if errors.Is(err, pgx.ErrNoRows) { |
| 129 | writeAPIError(w, http.StatusNotFound, "key not found") |
| 130 | return |
| 131 | } |
| 132 | h.d.Logger.ErrorContext(r.Context(), "api: get user key", "error", err) |
| 133 | writeAPIError(w, http.StatusInternalServerError, "lookup failed") |
| 134 | return |
| 135 | } |
| 136 | if k.Kind != "authentication" { |
| 137 | // Same shape as a non-existent key: the signing-keys surface |
| 138 | // has its own route that this endpoint deliberately doesn't |
| 139 | // expose. |
| 140 | writeAPIError(w, http.StatusNotFound, "key not found") |
| 141 | return |
| 142 | } |
| 143 | writeJSON(w, http.StatusOK, presentUserKey(k)) |
| 144 | } |
| 145 | |
| 146 | type userKeyCreateRequest struct { |
| 147 | Title string `json:"title"` |
| 148 | Key string `json:"key"` |
| 149 | } |
| 150 | |
| 151 | func (h *Handlers) userKeyCreate(w http.ResponseWriter, r *http.Request) { |
| 152 | auth := middleware.PATAuthFromContext(r.Context()) |
| 153 | if auth.UserID == 0 { |
| 154 | writeAPIError(w, http.StatusUnauthorized, "unauthenticated") |
| 155 | return |
| 156 | } |
| 157 | var body userKeyCreateRequest |
| 158 | if err := json.NewDecoder(r.Body).Decode(&body); err != nil { |
| 159 | writeAPIError(w, http.StatusBadRequest, "invalid JSON: "+err.Error()) |
| 160 | return |
| 161 | } |
| 162 | parsed, err := sshkey.Parse(body.Title, body.Key) |
| 163 | if err != nil { |
| 164 | writeAPIError(w, http.StatusUnprocessableEntity, sshKeyAPIErrorMessage(err)) |
| 165 | return |
| 166 | } |
| 167 | count, err := h.q.CountUserSSHKeys(r.Context(), h.d.Pool, auth.UserID) |
| 168 | if err != nil { |
| 169 | h.d.Logger.ErrorContext(r.Context(), "api: count user keys", "error", err) |
| 170 | writeAPIError(w, http.StatusInternalServerError, "create failed") |
| 171 | return |
| 172 | } |
| 173 | if count >= int64(sshkey.MaxKeysPerUser) { |
| 174 | writeAPIError(w, http.StatusUnprocessableEntity, "per-user SSH-key cap reached") |
| 175 | return |
| 176 | } |
| 177 | k, err := h.q.InsertUserSSHKey(r.Context(), h.d.Pool, usersdb.InsertUserSSHKeyParams{ |
| 178 | UserID: auth.UserID, |
| 179 | Title: parsed.Title, |
| 180 | FingerprintSha256: parsed.Fingerprint, |
| 181 | KeyType: parsed.Type, |
| 182 | KeyBits: int32(parsed.Bits), //nolint:gosec // RSA-bit ceiling is bounded by sshkey.Parse. |
| 183 | PublicKey: parsed.PublicKey, |
| 184 | Kind: "authentication", |
| 185 | }) |
| 186 | if err != nil { |
| 187 | var pgErr *pgconn.PgError |
| 188 | if errors.As(err, &pgErr) && pgErr.Code == "23505" { |
| 189 | writeAPIError(w, http.StatusUnprocessableEntity, "key already registered") |
| 190 | return |
| 191 | } |
| 192 | h.d.Logger.ErrorContext(r.Context(), "api: insert user key", "error", err) |
| 193 | writeAPIError(w, http.StatusInternalServerError, "create failed") |
| 194 | return |
| 195 | } |
| 196 | writeJSON(w, http.StatusCreated, presentUserKey(k)) |
| 197 | } |
| 198 | |
| 199 | func (h *Handlers) userKeyDelete(w http.ResponseWriter, r *http.Request) { |
| 200 | auth := middleware.PATAuthFromContext(r.Context()) |
| 201 | if auth.UserID == 0 { |
| 202 | writeAPIError(w, http.StatusUnauthorized, "unauthenticated") |
| 203 | return |
| 204 | } |
| 205 | id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) |
| 206 | if err != nil { |
| 207 | writeAPIError(w, http.StatusNotFound, "key not found") |
| 208 | return |
| 209 | } |
| 210 | rows, err := h.q.DeleteUserSSHKey(r.Context(), h.d.Pool, usersdb.DeleteUserSSHKeyParams{ |
| 211 | ID: id, UserID: auth.UserID, |
| 212 | }) |
| 213 | if err != nil { |
| 214 | h.d.Logger.ErrorContext(r.Context(), "api: delete user key", "error", err) |
| 215 | writeAPIError(w, http.StatusInternalServerError, "delete failed") |
| 216 | return |
| 217 | } |
| 218 | if rows == 0 { |
| 219 | writeAPIError(w, http.StatusNotFound, "key not found") |
| 220 | return |
| 221 | } |
| 222 | w.WriteHeader(http.StatusNoContent) |
| 223 | } |
| 224 | |
| 225 | // sshKeyAPIErrorMessage maps the typed parser errors to user-facing |
| 226 | // strings appropriate for an API client (no UI verbiage). |
| 227 | func sshKeyAPIErrorMessage(err error) string { |
| 228 | switch { |
| 229 | case errors.Is(err, sshkey.ErrTitleEmpty): |
| 230 | return "title is required" |
| 231 | case errors.Is(err, sshkey.ErrTitleTooLong): |
| 232 | return "title must be at most 80 characters" |
| 233 | case errors.Is(err, sshkey.ErrTitleControl): |
| 234 | return "title contains control characters" |
| 235 | case errors.Is(err, sshkey.ErrUnsupportedAlgo): |
| 236 | return "unsupported key algorithm (use ed25519, ECDSA, or RSA >= 2048 bits)" |
| 237 | case errors.Is(err, sshkey.ErrRSATooShort): |
| 238 | return "RSA keys must be at least 2048 bits" |
| 239 | case errors.Is(err, sshkey.ErrUnparseable): |
| 240 | return "could not parse key blob" |
| 241 | default: |
| 242 | return "invalid key" |
| 243 | } |
| 244 | } |
| 245 | |
// sanitizedURL returns a shallow copy of r.URL with Scheme and Host
// blanked. Only path, query and fragment information reaches
// apipage.LinkHeader, so proxy-internal hostnames are never echoed back.
// r.URL itself is left untouched.
func sanitizedURL(r *http.Request) *url.URL {
	clone := new(url.URL)
	*clone = *r.URL
	clone.Scheme, clone.Host = "", ""
	return clone
}
| 256 |