| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | // Package db owns the Postgres connection lifecycle. S01 ships the |
| 4 | // open/healthcheck/transaction helpers; later sprints add domain-specific |
| 5 | // wrappers but always go through this package for the pool. |
| 6 | package db |
| 7 | |
| 8 | import ( |
| 9 | "context" |
| 10 | "errors" |
| 11 | "fmt" |
| 12 | "os" |
| 13 | "time" |
| 14 | |
| 15 | "github.com/jackc/pgx/v5" |
| 16 | "github.com/jackc/pgx/v5/pgxpool" |
| 17 | ) |
| 18 | |
// Config holds the settings used to build the Postgres connection pool.
// S03 will populate it from the layered config loader; for S01 callers may
// set URL explicitly or leave it empty so Resolve falls back to the
// SHITHUB_DATABASE_URL environment variable. Resolve also fills in missing
// or invalid numeric values before Open uses the struct.
type Config struct {
	// URL is the connection string handed to pgxpool.ParseConfig.
	URL string

	// MaxConns and MinConns are copied verbatim into the pgxpool config's
	// MaxConns and MinConns by Open.
	MaxConns, MinConns int32

	// ConnectTimeout is used both as pgx's per-connection ConnectTimeout and
	// as the deadline Open applies to pool creation plus the initial ping.
	// StatementCancel is presumably a per-statement cancellation budget.
	// NOTE(review): Open does not currently apply StatementCancel to the
	// pool — confirm where (or whether) it is meant to be enforced.
	ConnectTimeout, StatementCancel time.Duration
}
| 29 | |
| 30 | // Defaults returns sensible defaults for a dev pool. Prod values land via |
| 31 | // the config loader in S03. |
| 32 | func Defaults() Config { |
| 33 | return Config{ |
| 34 | MaxConns: 10, |
| 35 | MinConns: 1, |
| 36 | ConnectTimeout: 5 * time.Second, |
| 37 | StatementCancel: 30 * time.Second, |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | // Resolve fills in URL from env if missing and clamps numeric defaults. |
| 42 | func (c Config) Resolve() Config { |
| 43 | if c.URL == "" { |
| 44 | c.URL = os.Getenv("SHITHUB_DATABASE_URL") |
| 45 | } |
| 46 | if c.MaxConns <= 0 { |
| 47 | c.MaxConns = 10 |
| 48 | } |
| 49 | if c.MinConns < 0 { |
| 50 | c.MinConns = 0 |
| 51 | } |
| 52 | if c.ConnectTimeout <= 0 { |
| 53 | c.ConnectTimeout = 5 * time.Second |
| 54 | } |
| 55 | if c.StatementCancel <= 0 { |
| 56 | c.StatementCancel = 30 * time.Second |
| 57 | } |
| 58 | return c |
| 59 | } |
| 60 | |
// ErrNoURL is the sentinel Open returns when no connection string is
// available: cfg.URL was empty and SHITHUB_DATABASE_URL was unset or empty.
// Callers can detect it with errors.Is.
var (
	ErrNoURL = errors.New("db: no DATABASE_URL configured (set SHITHUB_DATABASE_URL)")
)
| 64 | |
| 65 | // Open creates a new pgx pool from cfg. The caller owns the pool's lifecycle |
| 66 | // and must call pool.Close() on shutdown. |
| 67 | func Open(ctx context.Context, cfg Config) (*pgxpool.Pool, error) { |
| 68 | cfg = cfg.Resolve() |
| 69 | if cfg.URL == "" { |
| 70 | return nil, ErrNoURL |
| 71 | } |
| 72 | |
| 73 | pcfg, err := pgxpool.ParseConfig(cfg.URL) |
| 74 | if err != nil { |
| 75 | return nil, fmt.Errorf("db: parse config: %w", err) |
| 76 | } |
| 77 | pcfg.MaxConns = cfg.MaxConns |
| 78 | pcfg.MinConns = cfg.MinConns |
| 79 | pcfg.ConnConfig.ConnectTimeout = cfg.ConnectTimeout |
| 80 | // QueryCounter is a no-op when the request context wasn't built |
| 81 | // with WithCounter — production traffic pays one map lookup per |
| 82 | // query. Tests that assert "this route does ≤ N queries" wrap the |
| 83 | // request context to opt in. |
| 84 | pcfg.ConnConfig.Tracer = QueryCounter{} |
| 85 | |
| 86 | openCtx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) |
| 87 | defer cancel() |
| 88 | |
| 89 | pool, err := pgxpool.NewWithConfig(openCtx, pcfg) |
| 90 | if err != nil { |
| 91 | return nil, fmt.Errorf("db: open pool: %w", err) |
| 92 | } |
| 93 | |
| 94 | // Verify connectivity before returning. A pool that can't talk to |
| 95 | // Postgres at startup is a hard failure, not a retry-on-first-query |
| 96 | // situation. |
| 97 | if err := pool.Ping(openCtx); err != nil { |
| 98 | pool.Close() |
| 99 | return nil, fmt.Errorf("db: ping: %w", err) |
| 100 | } |
| 101 | |
| 102 | return pool, nil |
| 103 | } |
| 104 | |
| 105 | // Healthcheck performs a fast SELECT 1 against the pool with a short |
| 106 | // timeout. Used by /readyz. |
| 107 | func Healthcheck(ctx context.Context, pool *pgxpool.Pool) error { |
| 108 | if pool == nil { |
| 109 | return errors.New("db: nil pool") |
| 110 | } |
| 111 | hc, cancel := context.WithTimeout(ctx, 2*time.Second) |
| 112 | defer cancel() |
| 113 | var v int |
| 114 | if err := pool.QueryRow(hc, "SELECT 1").Scan(&v); err != nil { |
| 115 | return fmt.Errorf("db: healthcheck: %w", err) |
| 116 | } |
| 117 | if v != 1 { |
| 118 | return fmt.Errorf("db: healthcheck: unexpected scalar %d", v) |
| 119 | } |
| 120 | return nil |
| 121 | } |
| 122 | |
| 123 | // WithTx runs fn inside a Postgres transaction, committing on nil error and |
| 124 | // rolling back otherwise. Panics inside fn are recovered and re-raised after |
| 125 | // rollback so callers see them as panics rather than silent commits. |
| 126 | func WithTx(ctx context.Context, pool *pgxpool.Pool, fn func(pgx.Tx) error) (err error) { |
| 127 | tx, err := pool.Begin(ctx) |
| 128 | if err != nil { |
| 129 | return fmt.Errorf("db: begin: %w", err) |
| 130 | } |
| 131 | defer func() { |
| 132 | if p := recover(); p != nil { |
| 133 | _ = tx.Rollback(ctx) |
| 134 | panic(p) |
| 135 | } |
| 136 | if err != nil { |
| 137 | if rbErr := tx.Rollback(ctx); rbErr != nil && !errors.Is(rbErr, pgx.ErrTxClosed) { |
| 138 | err = fmt.Errorf("%w (rollback: %v)", err, rbErr) |
| 139 | } |
| 140 | return |
| 141 | } |
| 142 | if cmErr := tx.Commit(ctx); cmErr != nil { |
| 143 | err = fmt.Errorf("db: commit: %w", cmErr) |
| 144 | } |
| 145 | }() |
| 146 | err = fn(tx) |
| 147 | return err |
| 148 | } |
| 149 |