tenseleyflow/documentlanguagemodel / 6a4d880

Browse files

test(fixtures): add golden-output registry keyed by (name, torch_version)

Authored by espadonne
SHA
6a4d88042eae2b3322750f6a83e9033d86e36df7
Parents
949aa1b
Tree
573248a

1 changed file

StatusFile+-
A tests/fixtures/golden.py 104 0
tests/fixtures/golden.pyadded
@@ -0,0 +1,104 @@
1
+"""Golden-output registry for determinism tests.
2
+
3
+Goldens are JSON blobs under `tests/golden/` keyed by `(name, torch_version)`
4
+so bumping torch produces a new golden rather than silently invalidating
5
+the old one (audit F19).
6
+
7
+Contract:
8
+
9
+- `assert_golden(actual, name)` compares `actual` against the stored
10
+  golden; if no golden exists for the current torch version, the test
11
+  fails with an instruction to regenerate.
12
+- `pytest --update-goldens` regenerates instead of asserting; the
13
+  `update_goldens` fixture (`conftest.py`) surfaces the flag.
14
+- Golden content is deterministically serialized JSON (sorted keys,
15
+  floats written exactly as `json.dumps` emits them).
16
+"""
17
+
18
+from __future__ import annotations
19
+
20
+import json
21
+from pathlib import Path
22
+from typing import Any
23
+
24
+GOLDEN_ROOT = Path(__file__).resolve().parent.parent / "golden"
25
+
26
+
27
class MissingGoldenError(AssertionError):
    """Signal that the requested golden file is absent from disk.

    Derives from `AssertionError`, so code that catches assertion failures
    also catches a missing golden.
    """
29
+
30
+
31
def golden_path(name: str, *, torch_version: str | None = None) -> Path:
    """Return the file under `GOLDEN_ROOT` that stores the golden `name`.

    Args:
        name: Golden name; any "/" is flattened to "_" so the result is a
            single file directly inside `GOLDEN_ROOT`.
        torch_version: Version string to key on. `None` means "use the
            version of the torch currently installed".

    Returns:
        `GOLDEN_ROOT / "<name>.torch-<version>.json"` with both parts
        sanitized as described here.
    """
    if torch_version is None:
        torch_version = _current_torch_version()
    # Filesystem-safe version tag: "/", "+" and " " each become "_".
    # Periods are left untouched.
    version_tag = torch_version.translate(str.maketrans("/+ ", "___"))
    stem = name.replace("/", "_")
    return GOLDEN_ROOT / f"{stem}.torch-{version_tag}.json"
38
+
39
+
40
def load_golden(name: str, *, torch_version: str | None = None) -> Any:
    """Read and parse the stored golden for `(name, torch_version)`.

    Args:
        name: Golden name, resolved through `golden_path`.
        torch_version: Version to key on; `None` selects the installed torch.

    Returns:
        The decoded JSON content of the golden file.

    Raises:
        MissingGoldenError: The golden file does not exist.
    """
    target = golden_path(name, torch_version=torch_version)
    if target.exists():
        return json.loads(target.read_text(encoding="utf-8"))
    raise MissingGoldenError(
        f"No golden for {name!r} at {target}. "
        "Run with --update-goldens to regenerate after manual review.",
    )
48
+
49
+
50
def assert_golden(
    actual: Any,
    name: str,
    *,
    update: bool = False,
    torch_version: str | None = None,
) -> None:
    """Check `actual` against the golden for `name`, or rewrite the golden.

    Args:
        actual: JSON-serializable value produced by the test.
        name: Golden name; with the torch version it selects the file.
        update: When true, store `actual` as the new golden and skip the
            comparison entirely.
        torch_version: Version to key on; `None` selects the installed torch.

    Raises:
        MissingGoldenError: No golden file exists and `update` is false.
        AssertionError: The stored golden differs from `actual`; the message
            shows both values pretty-printed with sorted keys.
    """
    target = golden_path(name, torch_version=torch_version)

    if update:
        # Regeneration mode: create the golden directory if needed, then
        # overwrite the file without comparing.
        target.parent.mkdir(parents=True, exist_ok=True)
        _write_json(target, actual)
        return

    if not target.exists():
        raise MissingGoldenError(
            f"No golden for {name!r} at {target}. "
            "Run `pytest --update-goldens` after manual review to create it.",
        )

    stored = json.loads(target.read_text(encoding="utf-8"))
    # Both sides are compared as canonical JSON text, so values that
    # serialize identically are treated as equal.
    if _canonical(stored) == _canonical(actual):
        return

    stored_dump = json.dumps(stored, sort_keys=True, indent=2)
    actual_dump = json.dumps(actual, sort_keys=True, indent=2)
    report = (
        f"Golden mismatch for {name!r} (path: {target}).",
        "Expected:",
        stored_dump,
        "Actual:",
        actual_dump,
    )
    raise AssertionError("\n".join(report))
85
+
86
+
87
+# --- internals ---------------------------------------------------------------
88
+
89
+
90
def _current_torch_version() -> str:
    """Return the version string of the installed torch."""
    # Imported inside the function on purpose: importing this module must
    # not pull in torch for callers that never touch goldens.
    from torch import __version__ as installed_version

    return installed_version
95
+
96
+
97
def _canonical(value: Any) -> str:
    """Serialize `value` to compact JSON text with sorted keys.

    Used only for equality checks: two values compare equal when their
    canonical text is identical.
    """
    encoder = json.JSONEncoder(sort_keys=True, separators=(",", ":"))
    return encoder.encode(value)
100
+
101
+
102
def _write_json(path: Path, value: Any) -> None:
    """Write `value` to `path` as pretty-printed, sorted-key JSON.

    The output is UTF-8, indented by two spaces, and ends with exactly one
    line feed.

    Args:
        path: Destination file; its parent directory must already exist.
        value: JSON-serializable value to store.

    Raises:
        TypeError: `value` contains something `json.dumps` cannot serialize.
    """
    # Serialize before opening the file, so a failure to encode `value`
    # never truncates a golden that is already on disk.
    text = json.dumps(value, sort_keys=True, indent=2) + "\n"
    # newline="\n" switches off platform newline translation. Without it a
    # golden regenerated on Windows would contain CRLF and differ byte-wise
    # from one produced on POSIX.
    with path.open("w", encoding="utf-8", newline="\n") as handle:
        handle.write(text)