Go · 2892 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package storage
4
5 import (
6 "bytes"
7 "errors"
8 "io"
9 "os"
10 "path/filepath"
11 "strings"
12 "testing"
13 )
14
15 func TestWriteAtomic_HappyPath(t *testing.T) {
16 t.Parallel()
17 dir := t.TempDir()
18 path := filepath.Join(dir, "out.txt")
19 body := "hello atomic"
20 if err := WriteAtomic(path, strings.NewReader(body)); err != nil {
21 t.Fatalf("WriteAtomic: %v", err)
22 }
23 got, err := os.ReadFile(path)
24 if err != nil {
25 t.Fatalf("ReadFile: %v", err)
26 }
27 if string(got) != body {
28 t.Fatalf("got %q, want %q", got, body)
29 }
30 }
31
// failingReader is an io.Reader that yields bytes from data until failAt
// bytes have been delivered in total, then reports failErr. It injects a
// failure after a partial write: WriteAtomic must NOT leave a file at the
// destination path when its source fails mid-stream.
//
// The chunk that reaches failAt is returned together with failErr (io.Reader
// permits n > 0 alongside a non-nil error). If data is shorter than failAt,
// the error is reported as soon as data runs out. A nil failErr is replaced
// by io.ErrUnexpectedEOF. Together these guarantee a consumer such as io.Copy
// can never spin forever on (0, nil) reads.
type failingReader struct {
	data    []byte // bytes served before the failure point
	off     int    // number of bytes already delivered
	failAt  int    // total bytes to deliver before failing
	failErr error  // error reported once failAt (or the end of data) is reached
}

// Read implements io.Reader with the failure semantics described on the
// failingReader type.
func (f *failingReader) Read(p []byte) (int, error) {
	// Never serve past the data we actually hold, even if failAt is larger.
	limit := f.failAt
	if limit > len(f.data) {
		limit = len(f.data)
	}
	err := f.failErr
	if err == nil {
		err = io.ErrUnexpectedEOF
	}
	if f.off >= limit {
		return 0, err
	}
	n := copy(p, f.data[f.off:limit])
	f.off += n
	if f.off >= limit {
		// Deliver the final chunk and the error together.
		return n, err
	}
	return n, nil
}
55
56 func TestWriteAtomic_PartialWriteLeavesNoFile(t *testing.T) {
57 t.Parallel()
58 dir := t.TempDir()
59 path := filepath.Join(dir, "should-not-exist.txt")
60
61 r := &failingReader{
62 data: bytes.Repeat([]byte("x"), 1024),
63 failAt: 256,
64 failErr: errors.New("simulated crash"),
65 }
66 err := WriteAtomic(path, r)
67 if err == nil {
68 t.Fatal("expected error from failing reader")
69 }
70
71 // Destination must not exist.
72 if _, err := os.Stat(path); !os.IsNotExist(err) {
73 t.Fatalf("destination exists after failed atomic write: stat err=%v", err)
74 }
75
76 // No leftover .tmp.* files in the same directory either.
77 entries, err := os.ReadDir(dir)
78 if err != nil {
79 t.Fatalf("ReadDir: %v", err)
80 }
81 for _, e := range entries {
82 if strings.Contains(e.Name(), ".tmp.") {
83 t.Fatalf("leftover temp file: %s", e.Name())
84 }
85 }
86 }
87
88 func TestWriteAtomic_OverwritesExisting(t *testing.T) {
89 t.Parallel()
90 dir := t.TempDir()
91 path := filepath.Join(dir, "out.txt")
92 if err := os.WriteFile(path, []byte("old content"), 0o600); err != nil {
93 t.Fatalf("seed: %v", err)
94 }
95 if err := WriteAtomic(path, strings.NewReader("new content")); err != nil {
96 t.Fatalf("WriteAtomic: %v", err)
97 }
98 got, _ := os.ReadFile(path)
99 if string(got) != "new content" {
100 t.Fatalf("got %q, want %q", got, "new content")
101 }
102 }
103
104 func TestWriteAtomic_LargeBody(t *testing.T) {
105 t.Parallel()
106 dir := t.TempDir()
107 path := filepath.Join(dir, "big.bin")
108 body := bytes.Repeat([]byte{0xab}, 5*1024*1024) // 5 MiB
109 if err := WriteAtomic(path, bytes.NewReader(body)); err != nil {
110 t.Fatalf("WriteAtomic: %v", err)
111 }
112 f, err := os.Open(path)
113 if err != nil {
114 t.Fatalf("Open: %v", err)
115 }
116 defer func() { _ = f.Close() }()
117 h, err := io.ReadAll(f)
118 if err != nil {
119 t.Fatalf("ReadAll: %v", err)
120 }
121 if !bytes.Equal(h, body) {
122 t.Fatalf("body mismatch (len got=%d want=%d)", len(h), len(body))
123 }
124 }
125