Go · 2071 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package db
4
5 import (
6 "context"
7 "sync/atomic"
8
9 "github.com/jackc/pgx/v5"
10 )
11
// QueryCounter is a pgx QueryTracer that increments a per-context
// counter every time pgx starts a Query, QueryRow, or Exec. The
// counter lives on the request context, so the tracer is safe to
// install at pool-config time even when most requests don't care.
//
// Use case: handler integration tests assert "this route does ≤ N
// DB queries." The tracer + WithCounter / FromContext + a tiny
// middleware (web/middleware/query_count_assert.go) make that
// assertion a one-liner.
//
// Production runs leave the tracer installed but never call
// WithCounter. The per-query overhead is then a single ctx.Value
// lookup (a walk up the context chain) that finds no counter. No
// atomic operation is performed at all, because there is nothing to
// Add to.
//
// NOTE(review): only pgx's QueryTracer hook is implemented. Work that
// pgx routes through its other tracer interfaces (e.g. SendBatch via
// pgx.BatchTracer, CopyFrom via pgx.CopyFromTracer) is presumably not
// counted. Confirm against the pgx version in go.mod.
type QueryCounter struct{}
27
// Context plumbing for the per-request query counter.
type (
	// queryCounterKey is the context key under which WithCounter stores
	// a *counter. Because the type is unexported, no other package can
	// construct the key, so nothing outside db can collide with it or
	// overwrite it.
	queryCounterKey struct{}

	// counter accumulates the query count for one WithCounter scope.
	// The field is atomic because several goroutines may issue queries
	// under the same request context at once (rare, e.g. fanning out
	// per row). A plain int64 would race and could undercount.
	counter struct {
		n atomic.Int64
	}
)
36
37 // WithCounter returns a derived context that records query counts.
38 // Pair with Read.
39 func WithCounter(ctx context.Context) context.Context {
40 return context.WithValue(ctx, queryCounterKey{}, &counter{})
41 }
42
43 // FromContext reports how many tracer events have fired against ctx.
44 // Returns 0 when WithCounter wasn't called on this context.
45 func FromContext(ctx context.Context) int64 {
46 c, ok := ctx.Value(queryCounterKey{}).(*counter)
47 if !ok {
48 return 0
49 }
50 return c.n.Load()
51 }
52
53 // TraceQueryStart implements pgx.QueryTracer. The counter increments
54 // at start, not end, so a slow query still counts.
55 func (QueryCounter) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context {
56 if c, ok := ctx.Value(queryCounterKey{}).(*counter); ok {
57 c.n.Add(1)
58 }
59 return ctx
60 }
61
62 // TraceQueryEnd implements pgx.QueryTracer. No-op; the start tick is
63 // the meaningful event.
64 func (QueryCounter) TraceQueryEnd(context.Context, *pgx.Conn, pgx.TraceQueryEndData) {}
65