Go · 13401 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package webhook
4
5 import (
6 "bytes"
7 "context"
8 "encoding/json"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "math/rand"
14 "net/http"
15 "time"
16
17 "github.com/jackc/pgx/v5/pgtype"
18 "github.com/jackc/pgx/v5/pgxpool"
19
20 "github.com/tenseleyFlow/shithub/internal/auth/secretbox"
21 webhookdb "github.com/tenseleyFlow/shithub/internal/webhook/sqlc"
22 )
23
// Size limits enforced by the deliverer.
const (
	// ResponseBodyCap is the largest number of response-body bytes kept
	// in the DB for one delivery, so a malicious receiver cannot inflate
	// our storage. Longer bodies are cut to this size and recorded with
	// response_truncated=true.
	ResponseBodyCap = 32 << 10

	// MaxBodySize is the 25 MiB ceiling on an outgoing request body, per
	// the spec. Payloads we build are far smaller; the limit only guards
	// against pathological event shapes.
	MaxBodySize = 25 << 20
)
33
34 // DeliverDeps wires the deliverer.
35 type DeliverDeps struct {
36 Pool *pgxpool.Pool
37 Logger *slog.Logger
38 HTTPClient *http.Client
39 SecretBox *secretbox.Box
40 SSRF SSRFConfig
41 // JitterFn is plumbed for tests; nil => math/rand.Float64.
42 JitterFn func() float64
43 }
44
45 // Deliver handles one delivery row end-to-end: load + decrypt secret,
46 // validate URL via SSRF defense, POST, record outcome, schedule retry
47 // or mark terminal. Returns nil on success, an error to be logged on
48 // transient failure (the row is updated regardless).
49 func Deliver(ctx context.Context, deps DeliverDeps, deliveryID int64) error {
50 if deps.Pool == nil {
51 return errors.New("webhook deliver: nil Pool")
52 }
53 if deps.SecretBox == nil {
54 return errors.New("webhook deliver: nil SecretBox")
55 }
56 q := webhookdb.New()
57 row, err := q.GetDeliveryByID(ctx, deps.Pool, deliveryID)
58 if err != nil {
59 return fmt.Errorf("load delivery: %w", err)
60 }
61 // Already-terminal rows can land here when a redelivery is
62 // re-enqueued; nothing to do.
63 if row.Status == webhookdb.WebhookDeliveryStatusSucceeded ||
64 row.Status == webhookdb.WebhookDeliveryStatusFailedPermanent {
65 return nil
66 }
67 hook, err := q.GetWebhookByID(ctx, deps.Pool, row.WebhookID)
68 if err != nil {
69 return fmt.Errorf("load webhook: %w", err)
70 }
71 if !hook.Active || hook.DisabledAt.Valid {
72 // Owner disabled the hook between fanout and delivery — mark
73 // the delivery permanently failed so it doesn't loop.
74 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
75 ID: deliveryID,
76 ResponseStatus: pgtype.Int4{Valid: false},
77 ResponseHeaders: nil,
78 ResponseBody: nil,
79 ResponseTruncated: false,
80 ErrorSummary: "webhook disabled before delivery",
81 })
82 return nil
83 }
84 secret, err := OpenSecret(deps.SecretBox, hook.SecretCiphertext, hook.SecretNonce)
85 if err != nil {
86 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
87 ID: deliveryID,
88 ResponseStatus: pgtype.Int4{Valid: false},
89 ResponseHeaders: nil,
90 ResponseBody: nil,
91 ResponseTruncated: false,
92 ErrorSummary: "decrypt secret: " + err.Error(),
93 })
94 _ = q.AutoDisableWebhook(ctx, deps.Pool, webhookdb.AutoDisableWebhookParams{
95 ID: hook.ID,
96 DisabledReason: pgtype.Text{String: "secret decryption failure", Valid: true},
97 })
98 return nil
99 }
100
101 if err := deps.SSRF.Validate(hook.Url); err != nil {
102 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
103 ID: deliveryID,
104 ResponseStatus: pgtype.Int4{Valid: false},
105 ResponseHeaders: nil,
106 ResponseBody: nil,
107 ResponseTruncated: false,
108 ErrorSummary: "ssrf: " + err.Error(),
109 })
110 recordFailure(ctx, q, deps, hook)
111 return nil
112 }
113
114 if len(row.RequestBody) > MaxBodySize {
115 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
116 ID: deliveryID,
117 ResponseStatus: pgtype.Int4{Valid: false},
118 ResponseHeaders: nil,
119 ResponseBody: nil,
120 ResponseTruncated: false,
121 ErrorSummary: fmt.Sprintf("payload exceeds %d bytes", MaxBodySize),
122 })
123 return nil
124 }
125
126 httpClient := deps.HTTPClient
127 if httpClient == nil {
128 httpClient = deps.SSRF.HTTPClient()
129 }
130
131 req, err := http.NewRequestWithContext(ctx, http.MethodPost, hook.Url, bytes.NewReader(row.RequestBody))
132 if err != nil {
133 return scheduleRetryOrPermanent(ctx, q, deps, row, hook, nil, "build request: "+err.Error())
134 }
135 applyHeaders(req, row.RequestHeaders, row.DeliveryUuid.String(), secret)
136
137 resp, doErr := httpClient.Do(req)
138 if doErr != nil {
139 // Connection errors are retryable.
140 return scheduleRetryOrPermanent(ctx, q, deps, row, hook, nil, "transport: "+doErr.Error())
141 }
142 defer resp.Body.Close()
143
144 bodyBytes, truncated, _ := readCappedBody(resp.Body, ResponseBodyCap)
145 respHeaders, _ := json.Marshal(headerMap(resp.Header))
146
147 switch {
148 case resp.StatusCode >= 200 && resp.StatusCode < 400:
149 // 3xx is treated as success per spec — we don't follow
150 // redirects (the SSRF check covered the original URL only).
151 _ = q.MarkDeliverySucceeded(ctx, deps.Pool, webhookdb.MarkDeliverySucceededParams{
152 ID: deliveryID,
153 ResponseStatus: pgtype.Int4{Int32: int32(resp.StatusCode), Valid: true},
154 ResponseHeaders: respHeaders,
155 ResponseBody: bodyBytes,
156 ResponseTruncated: truncated,
157 })
158 _ = q.RecordWebhookSuccess(ctx, deps.Pool, hook.ID)
159 return nil
160 case isRetryableStatus(resp.StatusCode):
161 return scheduleRetryOrPermanentWithResp(ctx, q, deps, row, hook,
162 resp.StatusCode, respHeaders, bodyBytes, truncated,
163 fmt.Sprintf("HTTP %d (retryable)", resp.StatusCode))
164 default:
165 // Non-retryable 4xx (other than 408/429): terminal failure.
166 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
167 ID: deliveryID,
168 ResponseStatus: pgtype.Int4{Int32: int32(resp.StatusCode), Valid: true},
169 ResponseHeaders: respHeaders,
170 ResponseBody: bodyBytes,
171 ResponseTruncated: truncated,
172 ErrorSummary: fmt.Sprintf("HTTP %d (non-retryable)", resp.StatusCode),
173 })
174 recordFailure(ctx, q, deps, hook)
175 return nil
176 }
177 }
178
179 // scheduleRetryOrPermanent decides whether `row` should be retried
180 // (next_retry_at = now + backoff(attempt+1)) or marked terminal
181 // (attempt+1 > max_attempts). Used for transport-level failures with
182 // no HTTP response.
183 func scheduleRetryOrPermanent(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, row webhookdb.WebhookDelivery, hook webhookdb.Webhook, respHeaders []byte, summary string) error {
184 return scheduleRetryOrPermanentWithResp(ctx, q, deps, row, hook, 0, respHeaders, nil, false, summary)
185 }
186
187 func scheduleRetryOrPermanentWithResp(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, row webhookdb.WebhookDelivery, hook webhookdb.Webhook, status int, respHeaders []byte, body []byte, truncated bool, summary string) error {
188 nextAttempt := row.Attempt + 1
189 jitter := deps.JitterFn
190 if jitter == nil {
191 jitter = rand.Float64
192 }
193 if int(nextAttempt) > int(row.MaxAttempts) {
194 statusVal := pgtype.Int4{Valid: false}
195 if status > 0 {
196 statusVal = pgtype.Int4{Int32: int32(status), Valid: true}
197 }
198 _ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
199 ID: row.ID,
200 ResponseStatus: statusVal,
201 ResponseHeaders: respHeaders,
202 ResponseBody: body,
203 ResponseTruncated: truncated,
204 ErrorSummary: summary + " (max_attempts exhausted)",
205 })
206 recordFailure(ctx, q, deps, hook)
207 return nil
208 }
209 delay := Backoff(int(nextAttempt), jitter)
210 statusVal := pgtype.Int4{Valid: false}
211 if status > 0 {
212 statusVal = pgtype.Int4{Int32: int32(status), Valid: true}
213 }
214 _ = q.MarkDeliveryRetry(ctx, deps.Pool, webhookdb.MarkDeliveryRetryParams{
215 ID: row.ID,
216 ResponseStatus: statusVal,
217 ResponseHeaders: respHeaders,
218 ResponseBody: body,
219 ResponseTruncated: truncated,
220 NextRetryAt: pgtype.Timestamptz{Time: time.Now().Add(delay), Valid: true},
221 ErrorSummary: summary,
222 })
223 return nil
224 }
225
226 // recordFailure increments consecutive_failures and auto-disables when
227 // the threshold is crossed.
228 func recordFailure(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, hook webhookdb.Webhook) {
229 row, err := q.RecordWebhookFailure(ctx, deps.Pool, hook.ID)
230 if err != nil {
231 if deps.Logger != nil {
232 deps.Logger.WarnContext(ctx, "webhook deliver: record failure", "error", err)
233 }
234 return
235 }
236 if int(row.ConsecutiveFailures) >= int(row.AutoDisableThreshold) {
237 _ = q.AutoDisableWebhook(ctx, deps.Pool, webhookdb.AutoDisableWebhookParams{
238 ID: hook.ID,
239 DisabledReason: pgtype.Text{String: fmt.Sprintf("auto-disabled after %d consecutive failures", row.ConsecutiveFailures), Valid: true},
240 })
241 }
242 }
243
244 // applyHeaders writes the per-delivery headers + signature onto the
245 // outbound request. The HMAC is computed over the raw body so a
246 // receiver that re-serializes JSON will fail verification — that's the
247 // signal we want.
248 func applyHeaders(req *http.Request, headersJSON []byte, deliveryUUID, secret string) {
249 if len(headersJSON) > 0 {
250 var m map[string]string
251 if err := json.Unmarshal(headersJSON, &m); err == nil {
252 for k, v := range m {
253 req.Header.Set(k, v)
254 }
255 }
256 }
257 body, _ := io.ReadAll(req.Body)
258 req.Body = io.NopCloser(bytes.NewReader(body))
259 req.ContentLength = int64(len(body))
260 req.Header.Set("X-Shithub-Delivery", deliveryUUID)
261 req.Header.Set("X-Shithub-Signature-256", SignSHA256([]byte(secret), body))
262 }
263
// isRetryableStatus tells the deliverer whether an HTTP status warrants
// another attempt. Per spec the retryable set is 408 (Request Timeout),
// 429 (Too Many Requests) and the whole 5xx range; everything else is
// final.
func isRetryableStatus(status int) bool {
	switch status {
	case http.StatusRequestTimeout, http.StatusTooManyRequests:
		return true
	}
	return 500 <= status && status <= 599
}
273
// readCappedBody reads at most cap bytes of r into body. It reads one
// byte past the cap to detect overflow; when the source is longer than
// cap, body is cut to exactly cap bytes and truncated is true.
//
// A negative cap is treated as 0. If reading fails, the bytes read so far
// (still limited to cap) are returned together with the error.
//
// After a truncation the remainder is drained so the connection can be
// reused, but only up to a fixed budget: an endless or very large body
// from a hostile receiver must not keep the deliverer reading forever.
func readCappedBody(r io.Reader, cap int) (body []byte, truncated bool, err error) {
	// maxDrain bounds the bytes discarded after the cap is reached.
	const maxDrain = 256 << 10

	if cap < 0 {
		cap = 0
	}
	body, err = io.ReadAll(io.LimitReader(r, int64(cap)+1))
	// Enforce the cap before looking at err: a read can return data and
	// an error together, and the extra probe byte must never leak out.
	if len(body) > cap {
		body = body[:cap]
		truncated = true
	}
	if err != nil {
		return body, truncated, err
	}
	if truncated {
		// Best-effort politeness; a drain failure changes nothing for
		// the caller.
		_, _ = io.CopyN(io.Discard, r, maxDrain)
	}
	return body, truncated, nil
}
291
// headerMap flattens an http.Header into the {name: value} map we store.
// Only the first value of each key is kept — webhooks are POSTs to JSON
// endpoints, so multi-value headers are vanishingly rare — and keys with
// no values at all are left out. The result is never nil.
func headerMap(h http.Header) map[string]string {
	flat := make(map[string]string, len(h))
	for name, values := range h {
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}
304
305 // EnqueuePing creates a synthetic ping delivery and enqueues it. Used
306 // on webhook create / re-enable so the operator sees an immediate
307 // round-trip.
308 func EnqueuePing(ctx context.Context, deps FanoutDeps, hookID int64) error {
309 q := webhookdb.New()
310 hook, err := q.GetWebhookByID(ctx, deps.Pool, hookID)
311 if err != nil {
312 return fmt.Errorf("load webhook: %w", err)
313 }
314 body, _ := json.Marshal(map[string]any{
315 "zen": "Approachable is better than simple.",
316 "hook": map[string]any{"id": hook.ID, "url": hook.Url},
317 })
318 hdrs, _ := json.Marshal(map[string]string{
319 "User-Agent": "shithub-Hookshot",
320 "Content-Type": contentTypeHeader(hook.ContentType),
321 "X-Shithub-Event": "ping",
322 })
323 row, err := q.CreateDelivery(ctx, deps.Pool, webhookdb.CreateDeliveryParams{
324 WebhookID: hookID,
325 EventKind: "ping",
326 EventID: pgtype.Int8{Valid: false},
327 Payload: body,
328 RequestHeaders: hdrs,
329 RequestBody: body,
330 Attempt: 1,
331 MaxAttempts: 3,
332 NextRetryAt: pgtype.Timestamptz{Time: time.Now(), Valid: true},
333 Status: webhookdb.WebhookDeliveryStatusPending,
334 IdempotencyKey: idempotencyKey(hookID, 0, body),
335 RedeliverOf: pgtype.Int8{Valid: false},
336 })
337 if err != nil {
338 return fmt.Errorf("create ping delivery: %w", err)
339 }
340 if _, err := workerEnqueueDeliver(ctx, deps.Pool, row.ID); err != nil {
341 return fmt.Errorf("enqueue ping: %w", err)
342 }
343 return nil
344 }
345
346 // Redeliver clones a past delivery into a new pending row so the UI's
347 // "redeliver" button preserves the audit trail (redeliver_of points
348 // at the original).
349 func Redeliver(ctx context.Context, deps FanoutDeps, originalID int64) (int64, error) {
350 q := webhookdb.New()
351 orig, err := q.GetDeliveryByID(ctx, deps.Pool, originalID)
352 if err != nil {
353 return 0, fmt.Errorf("load original: %w", err)
354 }
355 row, err := q.CreateDelivery(ctx, deps.Pool, webhookdb.CreateDeliveryParams{
356 WebhookID: orig.WebhookID,
357 EventKind: orig.EventKind,
358 EventID: orig.EventID,
359 Payload: orig.Payload,
360 RequestHeaders: orig.RequestHeaders,
361 RequestBody: orig.RequestBody,
362 Attempt: 1,
363 MaxAttempts: orig.MaxAttempts,
364 NextRetryAt: pgtype.Timestamptz{Time: time.Now(), Valid: true},
365 Status: webhookdb.WebhookDeliveryStatusPending,
366 IdempotencyKey: orig.IdempotencyKey,
367 RedeliverOf: pgtype.Int8{Int64: orig.ID, Valid: true},
368 })
369 if err != nil {
370 return 0, fmt.Errorf("create redelivery: %w", err)
371 }
372 if _, err := workerEnqueueDeliver(ctx, deps.Pool, row.ID); err != nil {
373 return row.ID, fmt.Errorf("enqueue redelivery: %w", err)
374 }
375 return row.ID, nil
376 }
377