tenseleyflow/shithub / 3dd5a86

Browse files

S33: deliverer — sign + POST + retry/backoff + auto-disable

Authored by espadonne
SHA
3dd5a868c074457df3f947ecfe225e8c39166c79
Parents
cabf7a3
Tree
e0a8a62

1 changed file

StatusFile+-
A internal/webhook/deliver.go 377 0
internal/webhook/deliver.goadded
@@ -0,0 +1,377 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package webhook
4
+
5
+import (
6
+	"bytes"
7
+	"context"
8
+	"encoding/json"
9
+	"errors"
10
+	"fmt"
11
+	"io"
12
+	"log/slog"
13
+	"math/rand"
14
+	"net/http"
15
+	"time"
16
+
17
+	"github.com/jackc/pgx/v5/pgtype"
18
+	"github.com/jackc/pgx/v5/pgxpool"
19
+
20
+	"github.com/tenseleyFlow/shithub/internal/auth/secretbox"
21
+	webhookdb "github.com/tenseleyFlow/shithub/internal/webhook/sqlc"
22
+)
23
+
24
const (
	// ResponseBodyCap is the maximum number of response-body bytes kept
	// in the database, so a malicious receiver cannot inflate our
	// storage. Longer responses are cut to this size and stored with
	// response_truncated=true.
	ResponseBodyCap = 32 << 10

	// MaxBodySize is the 25 MiB ceiling on an outbound request body
	// required by the spec. Real payloads are far smaller; the limit
	// only guards against pathological event shapes.
	MaxBodySize = 25 << 20
)
33
+
34
+// DeliverDeps wires the deliverer.
35
+type DeliverDeps struct {
36
+	Pool       *pgxpool.Pool
37
+	Logger     *slog.Logger
38
+	HTTPClient *http.Client
39
+	SecretBox  *secretbox.Box
40
+	SSRF       SSRFConfig
41
+	// JitterFn is plumbed for tests; nil => math/rand.Float64.
42
+	JitterFn func() float64
43
+}
44
+
45
+// Deliver handles one delivery row end-to-end: load + decrypt secret,
46
+// validate URL via SSRF defense, POST, record outcome, schedule retry
47
+// or mark terminal. Returns nil on success, an error to be logged on
48
+// transient failure (the row is updated regardless).
49
+func Deliver(ctx context.Context, deps DeliverDeps, deliveryID int64) error {
50
+	if deps.Pool == nil {
51
+		return errors.New("webhook deliver: nil Pool")
52
+	}
53
+	if deps.SecretBox == nil {
54
+		return errors.New("webhook deliver: nil SecretBox")
55
+	}
56
+	q := webhookdb.New()
57
+	row, err := q.GetDeliveryByID(ctx, deps.Pool, deliveryID)
58
+	if err != nil {
59
+		return fmt.Errorf("load delivery: %w", err)
60
+	}
61
+	// Already-terminal rows can land here when a redelivery is
62
+	// re-enqueued; nothing to do.
63
+	if row.Status == webhookdb.WebhookDeliveryStatusSucceeded ||
64
+		row.Status == webhookdb.WebhookDeliveryStatusFailedPermanent {
65
+		return nil
66
+	}
67
+	hook, err := q.GetWebhookByID(ctx, deps.Pool, row.WebhookID)
68
+	if err != nil {
69
+		return fmt.Errorf("load webhook: %w", err)
70
+	}
71
+	if !hook.Active || hook.DisabledAt.Valid {
72
+		// Owner disabled the hook between fanout and delivery — mark
73
+		// the delivery permanently failed so it doesn't loop.
74
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
75
+			ID:                deliveryID,
76
+			ResponseStatus:    pgtype.Int4{Valid: false},
77
+			ResponseHeaders:   nil,
78
+			ResponseBody:      nil,
79
+			ResponseTruncated: false,
80
+			ErrorSummary:      "webhook disabled before delivery",
81
+		})
82
+		return nil
83
+	}
84
+	secret, err := OpenSecret(deps.SecretBox, hook.SecretCiphertext, hook.SecretNonce)
85
+	if err != nil {
86
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
87
+			ID:                deliveryID,
88
+			ResponseStatus:    pgtype.Int4{Valid: false},
89
+			ResponseHeaders:   nil,
90
+			ResponseBody:      nil,
91
+			ResponseTruncated: false,
92
+			ErrorSummary:      "decrypt secret: " + err.Error(),
93
+		})
94
+		_ = q.AutoDisableWebhook(ctx, deps.Pool, webhookdb.AutoDisableWebhookParams{
95
+			ID:             hook.ID,
96
+			DisabledReason: pgtype.Text{String: "secret decryption failure", Valid: true},
97
+		})
98
+		return nil
99
+	}
100
+
101
+	if err := deps.SSRF.Validate(hook.Url); err != nil {
102
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
103
+			ID:                deliveryID,
104
+			ResponseStatus:    pgtype.Int4{Valid: false},
105
+			ResponseHeaders:   nil,
106
+			ResponseBody:      nil,
107
+			ResponseTruncated: false,
108
+			ErrorSummary:      "ssrf: " + err.Error(),
109
+		})
110
+		recordFailure(ctx, q, deps, hook)
111
+		return nil
112
+	}
113
+
114
+	if len(row.RequestBody) > MaxBodySize {
115
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
116
+			ID:                deliveryID,
117
+			ResponseStatus:    pgtype.Int4{Valid: false},
118
+			ResponseHeaders:   nil,
119
+			ResponseBody:      nil,
120
+			ResponseTruncated: false,
121
+			ErrorSummary:      fmt.Sprintf("payload exceeds %d bytes", MaxBodySize),
122
+		})
123
+		return nil
124
+	}
125
+
126
+	httpClient := deps.HTTPClient
127
+	if httpClient == nil {
128
+		httpClient = deps.SSRF.HTTPClient()
129
+	}
130
+
131
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, hook.Url, bytes.NewReader(row.RequestBody))
132
+	if err != nil {
133
+		return scheduleRetryOrPermanent(ctx, q, deps, row, hook, nil, "build request: "+err.Error())
134
+	}
135
+	applyHeaders(req, row.RequestHeaders, row.DeliveryUuid.String(), secret)
136
+
137
+	resp, doErr := httpClient.Do(req)
138
+	if doErr != nil {
139
+		// Connection errors are retryable.
140
+		return scheduleRetryOrPermanent(ctx, q, deps, row, hook, nil, "transport: "+doErr.Error())
141
+	}
142
+	defer resp.Body.Close()
143
+
144
+	bodyBytes, truncated, _ := readCappedBody(resp.Body, ResponseBodyCap)
145
+	respHeaders, _ := json.Marshal(headerMap(resp.Header))
146
+
147
+	switch {
148
+	case resp.StatusCode >= 200 && resp.StatusCode < 400:
149
+		// 3xx is treated as success per spec — we don't follow
150
+		// redirects (the SSRF check covered the original URL only).
151
+		_ = q.MarkDeliverySucceeded(ctx, deps.Pool, webhookdb.MarkDeliverySucceededParams{
152
+			ID:                deliveryID,
153
+			ResponseStatus:    pgtype.Int4{Int32: int32(resp.StatusCode), Valid: true},
154
+			ResponseHeaders:   respHeaders,
155
+			ResponseBody:      bodyBytes,
156
+			ResponseTruncated: truncated,
157
+		})
158
+		_ = q.RecordWebhookSuccess(ctx, deps.Pool, hook.ID)
159
+		return nil
160
+	case isRetryableStatus(resp.StatusCode):
161
+		return scheduleRetryOrPermanentWithResp(ctx, q, deps, row, hook,
162
+			resp.StatusCode, respHeaders, bodyBytes, truncated,
163
+			fmt.Sprintf("HTTP %d (retryable)", resp.StatusCode))
164
+	default:
165
+		// Non-retryable 4xx (other than 408/429): terminal failure.
166
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
167
+			ID:                deliveryID,
168
+			ResponseStatus:    pgtype.Int4{Int32: int32(resp.StatusCode), Valid: true},
169
+			ResponseHeaders:   respHeaders,
170
+			ResponseBody:      bodyBytes,
171
+			ResponseTruncated: truncated,
172
+			ErrorSummary:      fmt.Sprintf("HTTP %d (non-retryable)", resp.StatusCode),
173
+		})
174
+		recordFailure(ctx, q, deps, hook)
175
+		return nil
176
+	}
177
+}
178
+
179
+// scheduleRetryOrPermanent decides whether `row` should be retried
180
+// (next_retry_at = now + backoff(attempt+1)) or marked terminal
181
+// (attempt+1 > max_attempts). Used for transport-level failures with
182
+// no HTTP response.
183
+func scheduleRetryOrPermanent(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, row webhookdb.WebhookDelivery, hook webhookdb.Webhook, respHeaders []byte, summary string) error {
184
+	return scheduleRetryOrPermanentWithResp(ctx, q, deps, row, hook, 0, respHeaders, nil, false, summary)
185
+}
186
+
187
+func scheduleRetryOrPermanentWithResp(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, row webhookdb.WebhookDelivery, hook webhookdb.Webhook, status int, respHeaders []byte, body []byte, truncated bool, summary string) error {
188
+	nextAttempt := row.Attempt + 1
189
+	jitter := deps.JitterFn
190
+	if jitter == nil {
191
+		jitter = rand.Float64
192
+	}
193
+	if int(nextAttempt) > int(row.MaxAttempts) {
194
+		statusVal := pgtype.Int4{Valid: false}
195
+		if status > 0 {
196
+			statusVal = pgtype.Int4{Int32: int32(status), Valid: true}
197
+		}
198
+		_ = q.MarkDeliveryPermanentFailure(ctx, deps.Pool, webhookdb.MarkDeliveryPermanentFailureParams{
199
+			ID:                row.ID,
200
+			ResponseStatus:    statusVal,
201
+			ResponseHeaders:   respHeaders,
202
+			ResponseBody:      body,
203
+			ResponseTruncated: truncated,
204
+			ErrorSummary:      summary + " (max_attempts exhausted)",
205
+		})
206
+		recordFailure(ctx, q, deps, hook)
207
+		return nil
208
+	}
209
+	delay := Backoff(int(nextAttempt), jitter)
210
+	statusVal := pgtype.Int4{Valid: false}
211
+	if status > 0 {
212
+		statusVal = pgtype.Int4{Int32: int32(status), Valid: true}
213
+	}
214
+	_ = q.MarkDeliveryRetry(ctx, deps.Pool, webhookdb.MarkDeliveryRetryParams{
215
+		ID:                row.ID,
216
+		ResponseStatus:    statusVal,
217
+		ResponseHeaders:   respHeaders,
218
+		ResponseBody:      body,
219
+		ResponseTruncated: truncated,
220
+		NextRetryAt:       pgtype.Timestamptz{Time: time.Now().Add(delay), Valid: true},
221
+		ErrorSummary:      summary,
222
+	})
223
+	return nil
224
+}
225
+
226
+// recordFailure increments consecutive_failures and auto-disables when
227
+// the threshold is crossed.
228
+func recordFailure(ctx context.Context, q *webhookdb.Queries, deps DeliverDeps, hook webhookdb.Webhook) {
229
+	row, err := q.RecordWebhookFailure(ctx, deps.Pool, hook.ID)
230
+	if err != nil {
231
+		if deps.Logger != nil {
232
+			deps.Logger.WarnContext(ctx, "webhook deliver: record failure", "error", err)
233
+		}
234
+		return
235
+	}
236
+	if int(row.ConsecutiveFailures) >= int(row.AutoDisableThreshold) {
237
+		_ = q.AutoDisableWebhook(ctx, deps.Pool, webhookdb.AutoDisableWebhookParams{
238
+			ID:             hook.ID,
239
+			DisabledReason: pgtype.Text{String: fmt.Sprintf("auto-disabled after %d consecutive failures", row.ConsecutiveFailures), Valid: true},
240
+		})
241
+	}
242
+}
243
+
244
+// applyHeaders writes the per-delivery headers + signature onto the
245
+// outbound request. The HMAC is computed over the raw body so a
246
+// receiver that re-serializes JSON will fail verification — that's the
247
+// signal we want.
248
+func applyHeaders(req *http.Request, headersJSON []byte, deliveryUUID, secret string) {
249
+	if len(headersJSON) > 0 {
250
+		var m map[string]string
251
+		if err := json.Unmarshal(headersJSON, &m); err == nil {
252
+			for k, v := range m {
253
+				req.Header.Set(k, v)
254
+			}
255
+		}
256
+	}
257
+	body, _ := io.ReadAll(req.Body)
258
+	req.Body = io.NopCloser(bytes.NewReader(body))
259
+	req.ContentLength = int64(len(body))
260
+	req.Header.Set("X-Shithub-Delivery", deliveryUUID)
261
+	req.Header.Set("X-Shithub-Signature-256", SignSHA256([]byte(secret), body))
262
+}
263
+
264
// isRetryableStatus tells the deliverer whether a response with this
// HTTP status should be rescheduled. Per spec the retryable set is 408
// (Request Timeout), 429 (Too Many Requests) and the whole 5xx range;
// every other status is final.
func isRetryableStatus(status int) bool {
	switch {
	case status == http.StatusRequestTimeout, status == http.StatusTooManyRequests:
		return true
	case status >= 500 && status < 600:
		return true
	default:
		return false
	}
}
273
+
274
// readCappedBody reads at most cap bytes of r into body. It reads one
// byte past the cap to detect overflow; when the source is longer than
// cap, the surplus is dropped and truncated is true. A negative cap is
// treated as 0.
//
// After a truncation a bounded amount of the remainder is drained so a
// keep-alive connection could be reused. The drain is deliberately
// limited: an unbounded drain would let a receiver that streams an
// endless body stall the deliverer, defeating the cap.
//
// On a read error the bytes read so far are returned (never more than
// cap) together with the error.
func readCappedBody(r io.Reader, cap int) (body []byte, truncated bool, err error) {
	// maxDrain bounds how much surplus is read and thrown away.
	const maxDrain = 256 << 10

	if cap < 0 {
		cap = 0
	}
	body, err = io.ReadAll(io.LimitReader(r, int64(cap)+1))
	if len(body) > cap {
		body = body[:cap]
		truncated = true
	}
	if err != nil {
		return body, truncated, err
	}
	if truncated {
		// Best effort only; a short or failed drain is harmless.
		_, _ = io.CopyN(io.Discard, r, maxDrain)
	}
	return body, truncated, nil
}
291
+
292
// headerMap flattens an http.Header into the {name: value} shape we
// persist. Only the first value of each header is kept — deliveries are
// POSTs to JSON endpoints, where multi-value headers are vanishingly
// rare — and headers with no values are left out. The result is never
// nil.
func headerMap(h http.Header) map[string]string {
	flat := make(map[string]string, len(h))
	for name := range h {
		values := h[name]
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}
304
+
305
+// EnqueuePing creates a synthetic ping delivery and enqueues it. Used
306
+// on webhook create / re-enable so the operator sees an immediate
307
+// round-trip.
308
+func EnqueuePing(ctx context.Context, deps FanoutDeps, hookID int64) error {
309
+	q := webhookdb.New()
310
+	hook, err := q.GetWebhookByID(ctx, deps.Pool, hookID)
311
+	if err != nil {
312
+		return fmt.Errorf("load webhook: %w", err)
313
+	}
314
+	body, _ := json.Marshal(map[string]any{
315
+		"zen":  "Approachable is better than simple.",
316
+		"hook": map[string]any{"id": hook.ID, "url": hook.Url},
317
+	})
318
+	hdrs, _ := json.Marshal(map[string]string{
319
+		"User-Agent":      "shithub-Hookshot",
320
+		"Content-Type":    contentTypeHeader(hook.ContentType),
321
+		"X-Shithub-Event": "ping",
322
+	})
323
+	row, err := q.CreateDelivery(ctx, deps.Pool, webhookdb.CreateDeliveryParams{
324
+		WebhookID:      hookID,
325
+		EventKind:      "ping",
326
+		EventID:        pgtype.Int8{Valid: false},
327
+		Payload:        body,
328
+		RequestHeaders: hdrs,
329
+		RequestBody:    body,
330
+		Attempt:        1,
331
+		MaxAttempts:    3,
332
+		NextRetryAt:    pgtype.Timestamptz{Time: time.Now(), Valid: true},
333
+		Status:         webhookdb.WebhookDeliveryStatusPending,
334
+		IdempotencyKey: idempotencyKey(hookID, 0, body),
335
+		RedeliverOf:    pgtype.Int8{Valid: false},
336
+	})
337
+	if err != nil {
338
+		return fmt.Errorf("create ping delivery: %w", err)
339
+	}
340
+	if _, err := workerEnqueueDeliver(ctx, deps.Pool, row.ID); err != nil {
341
+		return fmt.Errorf("enqueue ping: %w", err)
342
+	}
343
+	return nil
344
+}
345
+
346
+// Redeliver clones a past delivery into a new pending row so the UI's
347
+// "redeliver" button preserves the audit trail (redeliver_of points
348
+// at the original).
349
+func Redeliver(ctx context.Context, deps FanoutDeps, originalID int64) (int64, error) {
350
+	q := webhookdb.New()
351
+	orig, err := q.GetDeliveryByID(ctx, deps.Pool, originalID)
352
+	if err != nil {
353
+		return 0, fmt.Errorf("load original: %w", err)
354
+	}
355
+	row, err := q.CreateDelivery(ctx, deps.Pool, webhookdb.CreateDeliveryParams{
356
+		WebhookID:      orig.WebhookID,
357
+		EventKind:      orig.EventKind,
358
+		EventID:        orig.EventID,
359
+		Payload:        orig.Payload,
360
+		RequestHeaders: orig.RequestHeaders,
361
+		RequestBody:    orig.RequestBody,
362
+		Attempt:        1,
363
+		MaxAttempts:    orig.MaxAttempts,
364
+		NextRetryAt:    pgtype.Timestamptz{Time: time.Now(), Valid: true},
365
+		Status:         webhookdb.WebhookDeliveryStatusPending,
366
+		IdempotencyKey: orig.IdempotencyKey,
367
+		RedeliverOf:    pgtype.Int8{Int64: orig.ID, Valid: true},
368
+	})
369
+	if err != nil {
370
+		return 0, fmt.Errorf("create redelivery: %w", err)
371
+	}
372
+	if _, err := workerEnqueueDeliver(ctx, deps.Pool, row.ID); err != nil {
373
+		return row.ID, fmt.Errorf("enqueue redelivery: %w", err)
374
+	}
375
+	return row.ID, nil
376
+}
377
+