Go · 4146 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package db owns the Postgres connection lifecycle. S01 ships the
4 // open/healthcheck/transaction helpers; later sprints add domain-specific
5 // wrappers but always go through this package for the pool.
6 package db
7
8 import (
9 "context"
10 "errors"
11 "fmt"
12 "os"
13 "time"
14
15 "github.com/jackc/pgx/v5"
16 "github.com/jackc/pgx/v5/pgxpool"
17 )
18
// Config describes how Open should build the Postgres connection pool.
//
// Every field may be left at its zero value: Resolve substitutes the
// SHITHUB_DATABASE_URL environment variable for an empty URL and fills in
// defaults for unset or out-of-range limits and timeouts. Until the layered
// config loader lands in S03, callers populate these fields directly.
type Config struct {
	// URL is the Postgres connection string handed to pgxpool.ParseConfig.
	URL string

	// MaxConns and MinConns bound the pool size; Open copies them onto the
	// pgxpool configuration.
	MaxConns, MinConns int32

	// ConnectTimeout limits each connection attempt and the startup ping
	// that Open performs.
	ConnectTimeout time.Duration

	// StatementCancel is defaulted by Resolve but not read anywhere else in
	// this file. NOTE(review): presumably a per-statement timeout consumed
	// elsewhere — confirm before relying on it.
	StatementCancel time.Duration
}
29
30 // Defaults returns sensible defaults for a dev pool. Prod values land via
31 // the config loader in S03.
32 func Defaults() Config {
33 return Config{
34 MaxConns: 10,
35 MinConns: 1,
36 ConnectTimeout: 5 * time.Second,
37 StatementCancel: 30 * time.Second,
38 }
39 }
40
41 // Resolve fills in URL from env if missing and clamps numeric defaults.
42 func (c Config) Resolve() Config {
43 if c.URL == "" {
44 c.URL = os.Getenv("SHITHUB_DATABASE_URL")
45 }
46 if c.MaxConns <= 0 {
47 c.MaxConns = 10
48 }
49 if c.MinConns < 0 {
50 c.MinConns = 0
51 }
52 if c.ConnectTimeout <= 0 {
53 c.ConnectTimeout = 5 * time.Second
54 }
55 if c.StatementCancel <= 0 {
56 c.StatementCancel = 30 * time.Second
57 }
58 return c
59 }
60
var (
	// ErrNoURL is the sentinel Open returns when it has no connection string:
	// cfg.URL was empty and the SHITHUB_DATABASE_URL environment variable was
	// unset or empty. Match it with errors.Is.
	ErrNoURL = errors.New("db: no DATABASE_URL configured (set SHITHUB_DATABASE_URL)")
)
64
65 // Open creates a new pgx pool from cfg. The caller owns the pool's lifecycle
66 // and must call pool.Close() on shutdown.
67 func Open(ctx context.Context, cfg Config) (*pgxpool.Pool, error) {
68 cfg = cfg.Resolve()
69 if cfg.URL == "" {
70 return nil, ErrNoURL
71 }
72
73 pcfg, err := pgxpool.ParseConfig(cfg.URL)
74 if err != nil {
75 return nil, fmt.Errorf("db: parse config: %w", err)
76 }
77 pcfg.MaxConns = cfg.MaxConns
78 pcfg.MinConns = cfg.MinConns
79 pcfg.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
80 // QueryCounter is a no-op when the request context wasn't built
81 // with WithCounter — production traffic pays one map lookup per
82 // query. Tests that assert "this route does ≤ N queries" wrap the
83 // request context to opt in.
84 pcfg.ConnConfig.Tracer = QueryCounter{}
85
86 openCtx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout)
87 defer cancel()
88
89 pool, err := pgxpool.NewWithConfig(openCtx, pcfg)
90 if err != nil {
91 return nil, fmt.Errorf("db: open pool: %w", err)
92 }
93
94 // Verify connectivity before returning. A pool that can't talk to
95 // Postgres at startup is a hard failure, not a retry-on-first-query
96 // situation.
97 if err := pool.Ping(openCtx); err != nil {
98 pool.Close()
99 return nil, fmt.Errorf("db: ping: %w", err)
100 }
101
102 return pool, nil
103 }
104
105 // Healthcheck performs a fast SELECT 1 against the pool with a short
106 // timeout. Used by /readyz.
107 func Healthcheck(ctx context.Context, pool *pgxpool.Pool) error {
108 if pool == nil {
109 return errors.New("db: nil pool")
110 }
111 hc, cancel := context.WithTimeout(ctx, 2*time.Second)
112 defer cancel()
113 var v int
114 if err := pool.QueryRow(hc, "SELECT 1").Scan(&v); err != nil {
115 return fmt.Errorf("db: healthcheck: %w", err)
116 }
117 if v != 1 {
118 return fmt.Errorf("db: healthcheck: unexpected scalar %d", v)
119 }
120 return nil
121 }
122
123 // WithTx runs fn inside a Postgres transaction, committing on nil error and
124 // rolling back otherwise. Panics inside fn are recovered and re-raised after
125 // rollback so callers see them as panics rather than silent commits.
126 func WithTx(ctx context.Context, pool *pgxpool.Pool, fn func(pgx.Tx) error) (err error) {
127 tx, err := pool.Begin(ctx)
128 if err != nil {
129 return fmt.Errorf("db: begin: %w", err)
130 }
131 defer func() {
132 if p := recover(); p != nil {
133 _ = tx.Rollback(ctx)
134 panic(p)
135 }
136 if err != nil {
137 if rbErr := tx.Rollback(ctx); rbErr != nil && !errors.Is(rbErr, pgx.ErrTxClosed) {
138 err = fmt.Errorf("%w (rollback: %v)", err, rbErr)
139 }
140 return
141 }
142 if cmErr := tx.Commit(ctx); cmErr != nil {
143 err = fmt.Errorf("db: commit: %w", cmErr)
144 }
145 }()
146 err = fn(tx)
147 return err
148 }
149