tenseleyflow/shithub / 92b65dd

Browse files

Add dbtest harness: per-test DB cloned from migrated template

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
92b65dde0c48cd41dae6040884473f1f2ce82772
Parents
c5696d0
Tree
35ab3d6

1 changed file

StatusFile+-
A internal/testing/dbtest/dbtest.go 206 0
internal/testing/dbtest/dbtest.goadded
@@ -0,0 +1,206 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+// Package dbtest provides a parallel-safe test database harness. Each
4
+// caller gets a freshly cloned database from a template that has all
5
+// migrations applied.
6
+//
7
+// Usage:
8
+//
9
+//	func TestSomething(t *testing.T) {
10
+//	    pool := dbtest.NewTestDB(t)
11
+//	    // pool is a *pgxpool.Pool against a freshly cloned DB.
12
+//	}
13
+//
14
+// The harness reads SHITHUB_TEST_DATABASE_URL for the bootstrap connection
15
+// (used to CREATE/DROP databases). Tests are skipped if the env var is not
16
+// set, so unit tests stay green on machines without Postgres.
17
+package dbtest
18
+
19
+import (
20
+	"context"
21
+	"crypto/rand"
22
+	"encoding/hex"
23
+	"fmt"
24
+	"net/url"
25
+	"os"
26
+	"sync"
27
+	"testing"
28
+	"time"
29
+
30
+	"github.com/jackc/pgx/v5"
31
+	"github.com/jackc/pgx/v5/pgxpool"
32
+
33
+	"github.com/tenseleyFlow/shithub/internal/infra/db"
34
+	_ "github.com/tenseleyFlow/shithub/internal/migrationsfs" // register migrations
35
+)
36
+
37
+const envURL = "SHITHUB_TEST_DATABASE_URL"
38
+
39
+var (
40
+	templateOnce  sync.Once
41
+	templateName  string
42
+	templateError error
43
+)
44
+
45
+// NewTestDB returns a *pgxpool.Pool against a freshly created database
46
+// cloned from the per-test-suite template. The database is dropped on
47
+// t.Cleanup. Calls t.Skip if SHITHUB_TEST_DATABASE_URL is unset.
48
+func NewTestDB(t *testing.T) *pgxpool.Pool {
49
+	t.Helper()
50
+	bootURL := os.Getenv(envURL)
51
+	if bootURL == "" {
52
+		t.Skipf("dbtest: %s not set; skipping integration test", envURL)
53
+	}
54
+
55
+	templateOnce.Do(func() {
56
+		templateName, templateError = ensureTemplate(bootURL)
57
+	})
58
+	if templateError != nil {
59
+		t.Fatalf("dbtest: template setup: %v", templateError)
60
+	}
61
+
62
+	dbName := uniqueDBName()
63
+	if err := createFromTemplate(bootURL, dbName, templateName); err != nil {
64
+		t.Fatalf("dbtest: clone db: %v", err)
65
+	}
66
+	t.Cleanup(func() {
67
+		if err := dropDB(bootURL, dbName); err != nil {
68
+			t.Logf("dbtest: drop %s: %v", dbName, err)
69
+		}
70
+	})
71
+
72
+	pool, err := db.Open(context.Background(), db.Config{
73
+		URL:      replaceDBName(bootURL, dbName),
74
+		MaxConns: 4,
75
+		MinConns: 0,
76
+	})
77
+	if err != nil {
78
+		t.Fatalf("dbtest: open clone: %v", err)
79
+	}
80
+	t.Cleanup(pool.Close)
81
+	return pool
82
+}
83
+
84
+// ensureTemplate creates the template database (if absent) and applies all
85
+// migrations to it. Idempotent: subsequent runs reuse the existing template.
86
+func ensureTemplate(bootURL string) (string, error) {
87
+	const name = "shithub_test_template"
88
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
89
+	defer cancel()
90
+
91
+	exists, err := dbExists(ctx, bootURL, name)
92
+	if err != nil {
93
+		return "", err
94
+	}
95
+	if !exists {
96
+		if err := createDB(bootURL, name); err != nil {
97
+			return "", err
98
+		}
99
+	}
100
+
101
+	// Apply migrations to the template.
102
+	tplURL := replaceDBName(bootURL, name)
103
+	if err := db.Migrate(ctx, db.Config{URL: tplURL}, db.MigrateUp); err != nil {
104
+		return "", fmt.Errorf("dbtest: migrate template: %w", err)
105
+	}
106
+
107
+	// Mark the database as a template so CREATE ... TEMPLATE works without
108
+	// requiring superuser. (PG allows a non-template database as TEMPLATE
109
+	// when no other connections exist; marking it explicitly is cleaner.)
110
+	if err := execBoot(bootURL, "ALTER DATABASE "+quoteIdent(name)+" IS_TEMPLATE TRUE"); err != nil {
111
+		// Some Postgres versions/configs reject this; tolerate.
112
+		_ = err
113
+	}
114
+	return name, nil
115
+}
116
+
117
+func dbExists(ctx context.Context, bootURL, name string) (bool, error) {
118
+	conn, err := pgx.Connect(ctx, bootURL)
119
+	if err != nil {
120
+		return false, fmt.Errorf("dbtest: connect: %w", err)
121
+	}
122
+	defer func() { _ = conn.Close(ctx) }()
123
+	var exists bool
124
+	err = conn.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)", name).Scan(&exists)
125
+	if err != nil {
126
+		return false, fmt.Errorf("dbtest: pg_database query: %w", err)
127
+	}
128
+	return exists, nil
129
+}
130
+
131
+func createDB(bootURL, name string) error {
132
+	return execBoot(bootURL, "CREATE DATABASE "+quoteIdent(name))
133
+}
134
+
135
+func dropDB(bootURL, name string) error {
136
+	// Force-disconnect any leftover sessions; harmless if there are none.
137
+	_ = execBoot(bootURL, "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = "+quoteLiteral(name))
138
+	return execBoot(bootURL, "DROP DATABASE IF EXISTS "+quoteIdent(name))
139
+}
140
+
141
+func createFromTemplate(bootURL, name, template string) error {
142
+	return execBoot(bootURL, "CREATE DATABASE "+quoteIdent(name)+" TEMPLATE "+quoteIdent(template))
143
+}
144
+
145
+func execBoot(bootURL, sql string) error {
146
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
147
+	defer cancel()
148
+	conn, err := pgx.Connect(ctx, bootURL)
149
+	if err != nil {
150
+		return fmt.Errorf("dbtest: connect: %w", err)
151
+	}
152
+	defer func() { _ = conn.Close(ctx) }()
153
+	_, err = conn.Exec(ctx, sql)
154
+	if err != nil {
155
+		return fmt.Errorf("dbtest: %s: %w", sql, err)
156
+	}
157
+	return nil
158
+}
159
+
160
+// replaceDBName rewrites the path component of a postgres URL to point at a
161
+// different database.
162
+func replaceDBName(bootURL, name string) string {
163
+	u, err := url.Parse(bootURL)
164
+	if err != nil {
165
+		return bootURL
166
+	}
167
+	u.Path = "/" + name
168
+	return u.String()
169
+}
170
+
171
+// uniqueDBName returns a per-test-database name. Hex prefix avoids clashes
172
+// across parallel test runs.
173
+func uniqueDBName() string {
174
+	var b [6]byte
175
+	_, _ = rand.Read(b[:])
176
+	return "shithub_test_" + hex.EncodeToString(b[:])
177
+}
178
+
179
+// quoteIdent wraps an identifier in double quotes and escapes any embedded
180
+// double quotes. For test-DB names we generate the names ourselves so this
181
+// is purely belt-and-braces.
182
+func quoteIdent(s string) string {
183
+	out := make([]byte, 0, len(s)+2)
184
+	out = append(out, '"')
185
+	for i := 0; i < len(s); i++ {
186
+		if s[i] == '"' {
187
+			out = append(out, '"')
188
+		}
189
+		out = append(out, s[i])
190
+	}
191
+	out = append(out, '"')
192
+	return string(out)
193
+}
194
+
195
+func quoteLiteral(s string) string {
196
+	out := make([]byte, 0, len(s)+2)
197
+	out = append(out, '\'')
198
+	for i := 0; i < len(s); i++ {
199
+		if s[i] == '\'' {
200
+			out = append(out, '\'')
201
+		}
202
+		out = append(out, s[i])
203
+	}
204
+	out = append(out, '\'')
205
+	return string(out)
206
+}