tenseleyflow/shithub / 5476341

Browse files

Add WriteAtomic helper with crash-safe tempfile-then-rename + fault-injection test

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
547634118a419365b4d5b9724a4356a23b492bf8
Parents
75cbeea
Tree
060e623

2 changed files

StatusFile+-
A internal/infra/storage/atomic.go 62 0
A internal/infra/storage/atomic_test.go 124 0
internal/infra/storage/atomic.goadded
@@ -0,0 +1,62 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package storage
4
+
5
+import (
6
+	"crypto/rand"
7
+	"encoding/hex"
8
+	"fmt"
9
+	"io"
10
+	"os"
11
+	"path/filepath"
12
+)
13
+
14
+// WriteAtomic writes src to path via a tempfile in the same directory,
15
+// fsyncs, and renames. A crash between write and rename leaves the temp
16
+// file behind (callers may sweep these on startup) but never a partial
17
+// file at path.
18
+//
19
+// The temp file MUST live on the same mount as path so the rename is
20
+// atomic — callers should not pass paths that cross mount points.
21
+func WriteAtomic(path string, src io.Reader) error {
22
+	dir := filepath.Dir(path)
23
+	suffix, err := randomSuffix()
24
+	if err != nil {
25
+		return fmt.Errorf("storage: atomic: random suffix: %w", err)
26
+	}
27
+	tmp := filepath.Join(dir, "."+filepath.Base(path)+".tmp."+suffix)
28
+
29
+	f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
30
+	if err != nil {
31
+		return fmt.Errorf("storage: atomic: open temp: %w", err)
32
+	}
33
+	cleanup := func() { _ = os.Remove(tmp) }
34
+
35
+	if _, err := io.Copy(f, src); err != nil {
36
+		_ = f.Close()
37
+		cleanup()
38
+		return fmt.Errorf("storage: atomic: copy: %w", err)
39
+	}
40
+	if err := f.Sync(); err != nil {
41
+		_ = f.Close()
42
+		cleanup()
43
+		return fmt.Errorf("storage: atomic: fsync: %w", err)
44
+	}
45
+	if err := f.Close(); err != nil {
46
+		cleanup()
47
+		return fmt.Errorf("storage: atomic: close: %w", err)
48
+	}
49
+	if err := os.Rename(tmp, path); err != nil {
50
+		cleanup()
51
+		return fmt.Errorf("storage: atomic: rename: %w", err)
52
+	}
53
+	return nil
54
+}
55
+
56
+func randomSuffix() (string, error) {
57
+	var b [8]byte
58
+	if _, err := rand.Read(b[:]); err != nil {
59
+		return "", err
60
+	}
61
+	return hex.EncodeToString(b[:]), nil
62
+}
internal/infra/storage/atomic_test.goadded
@@ -0,0 +1,124 @@
1
+// SPDX-License-Identifier: AGPL-3.0-or-later
2
+
3
+package storage
4
+
5
+import (
6
+	"bytes"
7
+	"errors"
8
+	"io"
9
+	"os"
10
+	"path/filepath"
11
+	"strings"
12
+	"testing"
13
+)
14
+
15
+func TestWriteAtomic_HappyPath(t *testing.T) {
16
+	t.Parallel()
17
+	dir := t.TempDir()
18
+	path := filepath.Join(dir, "out.txt")
19
+	body := "hello atomic"
20
+	if err := WriteAtomic(path, strings.NewReader(body)); err != nil {
21
+		t.Fatalf("WriteAtomic: %v", err)
22
+	}
23
+	got, err := os.ReadFile(path)
24
+	if err != nil {
25
+		t.Fatalf("ReadFile: %v", err)
26
+	}
27
+	if string(got) != body {
28
+		t.Fatalf("got %q, want %q", got, body)
29
+	}
30
+}
31
+
32
+// failingReader returns N bytes then an error. Used to inject a failure
33
+// after partial write — WriteAtomic must NOT leave a file at path.
34
+type failingReader struct {
35
+	data    []byte
36
+	off     int
37
+	failAt  int
38
+	failErr error
39
+}
40
+
41
+func (f *failingReader) Read(p []byte) (int, error) {
42
+	if f.off >= f.failAt {
43
+		return 0, f.failErr
44
+	}
45
+	n := copy(p, f.data[f.off:])
46
+	if f.off+n > f.failAt {
47
+		n = f.failAt - f.off
48
+	}
49
+	f.off += n
50
+	if f.off >= f.failAt {
51
+		return n, f.failErr
52
+	}
53
+	return n, nil
54
+}
55
+
56
+func TestWriteAtomic_PartialWriteLeavesNoFile(t *testing.T) {
57
+	t.Parallel()
58
+	dir := t.TempDir()
59
+	path := filepath.Join(dir, "should-not-exist.txt")
60
+
61
+	r := &failingReader{
62
+		data:    bytes.Repeat([]byte("x"), 1024),
63
+		failAt:  256,
64
+		failErr: errors.New("simulated crash"),
65
+	}
66
+	err := WriteAtomic(path, r)
67
+	if err == nil {
68
+		t.Fatal("expected error from failing reader")
69
+	}
70
+
71
+	// Destination must not exist.
72
+	if _, err := os.Stat(path); !os.IsNotExist(err) {
73
+		t.Fatalf("destination exists after failed atomic write: stat err=%v", err)
74
+	}
75
+
76
+	// No leftover .tmp.* files in the same directory either.
77
+	entries, err := os.ReadDir(dir)
78
+	if err != nil {
79
+		t.Fatalf("ReadDir: %v", err)
80
+	}
81
+	for _, e := range entries {
82
+		if strings.Contains(e.Name(), ".tmp.") {
83
+			t.Fatalf("leftover temp file: %s", e.Name())
84
+		}
85
+	}
86
+}
87
+
88
+func TestWriteAtomic_OverwritesExisting(t *testing.T) {
89
+	t.Parallel()
90
+	dir := t.TempDir()
91
+	path := filepath.Join(dir, "out.txt")
92
+	if err := os.WriteFile(path, []byte("old content"), 0o600); err != nil {
93
+		t.Fatalf("seed: %v", err)
94
+	}
95
+	if err := WriteAtomic(path, strings.NewReader("new content")); err != nil {
96
+		t.Fatalf("WriteAtomic: %v", err)
97
+	}
98
+	got, _ := os.ReadFile(path)
99
+	if string(got) != "new content" {
100
+		t.Fatalf("got %q, want %q", got, "new content")
101
+	}
102
+}
103
+
104
+func TestWriteAtomic_LargeBody(t *testing.T) {
105
+	t.Parallel()
106
+	dir := t.TempDir()
107
+	path := filepath.Join(dir, "big.bin")
108
+	body := bytes.Repeat([]byte{0xab}, 5*1024*1024) // 5 MiB
109
+	if err := WriteAtomic(path, bytes.NewReader(body)); err != nil {
110
+		t.Fatalf("WriteAtomic: %v", err)
111
+	}
112
+	f, err := os.Open(path)
113
+	if err != nil {
114
+		t.Fatalf("Open: %v", err)
115
+	}
116
+	defer func() { _ = f.Close() }()
117
+	h, err := io.ReadAll(f)
118
+	if err != nil {
119
+		t.Fatalf("ReadAll: %v", err)
120
+	}
121
+	if !bytes.Equal(h, body) {
122
+		t.Fatalf("body mismatch (len got=%d want=%d)", len(h), len(body))
123
+	}
124
+}