Go · 5839 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 // Package dbtest provides a parallel-safe test database harness. Each
4 // caller gets a freshly cloned database from a template that has all
5 // migrations applied.
6 //
7 // Usage:
8 //
9 // func TestSomething(t *testing.T) {
10 // pool := dbtest.NewTestDB(t)
11 // // pool is a *pgxpool.Pool against a freshly cloned DB.
12 // }
13 //
14 // The harness reads SHITHUB_TEST_DATABASE_URL for the bootstrap connection
15 // (used to CREATE/DROP databases). Tests are skipped if the env var is not
16 // set, so unit tests stay green on machines without Postgres.
17 package dbtest
18
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/tenseleyFlow/shithub/internal/infra/db"
	_ "github.com/tenseleyFlow/shithub/internal/migrationsfs" // register migrations
)
36
// envURL is the environment variable that carries the bootstrap connection
// URL; NewTestDB skips the calling test when it is empty.
const envURL = "SHITHUB_TEST_DATABASE_URL"

// Template setup happens at most once per test binary: templateOnce guards
// the call to ensureTemplate, and its results are cached in templateName and
// templateError for every later NewTestDB call.
var templateOnce sync.Once
var templateName string
var templateError error
44
45 // NewTestDB returns a *pgxpool.Pool against a freshly created database
46 // cloned from the per-test-suite template. The database is dropped on
47 // t.Cleanup. Calls t.Skip if SHITHUB_TEST_DATABASE_URL is unset.
48 func NewTestDB(t *testing.T) *pgxpool.Pool {
49 t.Helper()
50 bootURL := os.Getenv(envURL)
51 if bootURL == "" {
52 t.Skipf("dbtest: %s not set; skipping integration test", envURL)
53 }
54
55 templateOnce.Do(func() {
56 templateName, templateError = ensureTemplate(bootURL)
57 })
58 if templateError != nil {
59 t.Fatalf("dbtest: template setup: %v", templateError)
60 }
61
62 dbName := uniqueDBName()
63 if err := createFromTemplate(bootURL, dbName, templateName); err != nil {
64 t.Fatalf("dbtest: clone db: %v", err)
65 }
66 t.Cleanup(func() {
67 if err := dropDB(bootURL, dbName); err != nil {
68 t.Logf("dbtest: drop %s: %v", dbName, err)
69 }
70 })
71
72 pool, err := db.Open(context.Background(), db.Config{
73 URL: replaceDBName(bootURL, dbName),
74 MaxConns: 4,
75 MinConns: 0,
76 })
77 if err != nil {
78 t.Fatalf("dbtest: open clone: %v", err)
79 }
80 t.Cleanup(pool.Close)
81 return pool
82 }
83
84 // ensureTemplate creates the template database (if absent) and applies all
85 // migrations to it. Idempotent: subsequent runs reuse the existing template.
86 func ensureTemplate(bootURL string) (string, error) {
87 const name = "shithub_test_template"
88 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
89 defer cancel()
90
91 exists, err := dbExists(ctx, bootURL, name)
92 if err != nil {
93 return "", err
94 }
95 if !exists {
96 if err := createDB(bootURL, name); err != nil {
97 return "", err
98 }
99 }
100
101 // Apply migrations to the template.
102 tplURL := replaceDBName(bootURL, name)
103 if err := db.Migrate(ctx, db.Config{URL: tplURL}, db.MigrateUp); err != nil {
104 return "", fmt.Errorf("dbtest: migrate template: %w", err)
105 }
106
107 // Mark the database as a template so CREATE ... TEMPLATE works without
108 // requiring superuser. (PG allows a non-template database as TEMPLATE
109 // when no other connections exist; marking it explicitly is cleaner.)
110 if err := execBoot(bootURL, "ALTER DATABASE "+quoteIdent(name)+" IS_TEMPLATE TRUE"); err != nil {
111 // Some Postgres versions/configs reject this; tolerate.
112 _ = err
113 }
114 return name, nil
115 }
116
117 func dbExists(ctx context.Context, bootURL, name string) (bool, error) {
118 conn, err := pgx.Connect(ctx, bootURL)
119 if err != nil {
120 return false, fmt.Errorf("dbtest: connect: %w", err)
121 }
122 defer func() { _ = conn.Close(ctx) }()
123 var exists bool
124 err = conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)", name).Scan(&exists)
125 if err != nil {
126 return false, fmt.Errorf("dbtest: pg_database query: %w", err)
127 }
128 return exists, nil
129 }
130
131 func createDB(bootURL, name string) error {
132 return execBoot(bootURL, "CREATE DATABASE "+quoteIdent(name))
133 }
134
135 func dropDB(bootURL, name string) error {
136 // Force-disconnect any leftover sessions; harmless if there are none.
137 _ = execBoot(bootURL, "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = "+quoteLiteral(name))
138 return execBoot(bootURL, "DROP DATABASE IF EXISTS "+quoteIdent(name))
139 }
140
141 func createFromTemplate(bootURL, name, template string) error {
142 return execBoot(bootURL, "CREATE DATABASE "+quoteIdent(name)+" TEMPLATE "+quoteIdent(template))
143 }
144
145 func execBoot(bootURL, sql string) error {
146 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
147 defer cancel()
148 conn, err := pgx.Connect(ctx, bootURL)
149 if err != nil {
150 return fmt.Errorf("dbtest: connect: %w", err)
151 }
152 defer func() { _ = conn.Close(ctx) }()
153 _, err = conn.Exec(ctx, sql)
154 if err != nil {
155 return fmt.Errorf("dbtest: %s: %w", sql, err)
156 }
157 return nil
158 }
159
// replaceDBName returns bootURL with its database (the URL path) replaced by
// name. Scheme, credentials, host, port and other query parameters are kept.
//
// A "dbname" or "database" query parameter is removed as well. pgx, like
// libpq, lets such a parameter override the path. Without stripping it, the
// rewritten URL would silently point back at the original database.
//
// bootURL is expected in URL form (postgres://...). If it cannot be parsed,
// it is returned unchanged.
// NOTE(review): that fallback makes the caller connect to the bootstrap
// database itself; consider surfacing an error instead.
func replaceDBName(bootURL, name string) string {
	u, err := url.Parse(bootURL)
	if err != nil {
		return bootURL
	}
	u.Path = "/" + name
	u.RawPath = "" // let net/url derive the escaped form from Path
	if q := u.Query(); q.Has("dbname") || q.Has("database") {
		q.Del("dbname")
		q.Del("database")
		u.RawQuery = q.Encode()
	}
	return u.String()
}
170
// uniqueDBName builds a database name of the form "shithub_test_<12 hex
// digits>" from 6 bytes of crypto/rand output, so concurrently running tests
// and test binaries are very unlikely to choose the same name.
func uniqueDBName() string {
	suffix := make([]byte, 6)
	// NOTE(review): the read error is ignored; on failure the suffix would
	// not be random. (Since Go 1.24 crypto/rand.Read never returns an error.)
	_, _ = rand.Read(suffix)
	return "shithub_test_" + hex.EncodeToString(suffix)
}
178
// quoteIdent returns s as a double-quoted SQL identifier, doubling any
// embedded double quote (a"b -> "a""b"). The names passed in here are
// generated by this package, so this is belt-and-braces rather than a
// defence against hostile input.
func quoteIdent(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
194
// quoteLiteral returns s as a single-quoted SQL string literal, with
// embedded single quotes doubled (o'k -> 'o''k').
//
// If s contains a backslash, the literal is written in escape-string form
// (E'...') with backslashes doubled. It is then read the same way whatever
// the server's standard_conforming_strings setting, which is the approach
// libpq's PQescapeLiteral takes. A plain '...' literal would treat the
// backslash as an escape when that setting is off.
func quoteLiteral(s string) string {
	escaped := strings.ReplaceAll(s, "'", "''")
	if strings.Contains(s, `\`) {
		return "E'" + strings.ReplaceAll(escaped, `\`, `\\`) + "'"
	}
	return "'" + escaped + "'"
}
207