tenseleyflow/documentlanguagemodel / 233d743

feat(train): disk preflight + two-phase checkpoint commit (audit F12)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 233d743f0a0df6aef3d7ad284c060a48a9cb23d1
Parents: ae2aea1
Tree: d74479a

2 changed files

Status  File                                  +    -
A       src/dlm/train/checkpoint_commit.py  143    0
A       src/dlm/train/disk_preflight.py      84    0
src/dlm/train/checkpoint_commit.py (added)
@@ -0,0 +1,143 @@
+"""Two-phase checkpoint commit (audit F12).
+
+Invariant
+---------
+
+`adapter/current.txt` **never** points at a half-written version
+directory. Either the old version is still authoritative, or the new
+one is — never a partial state.
+
+Lifecycle
+---------
+
+1. `allocate_next_version(store)` picks `vNNNN` where `NNNN` is one
+   above the highest existing version directory (or `0001` on a fresh
+   store). Creates the empty directory.
+2. Caller populates the directory — `adapter.save_pretrained()` writes
+   the adapter config + weights; `state_sidecar.save_state()` writes
+   `training_state.pt` + sha256.
+3. `fsync_dir(path)` flushes the directory entry to disk.
+4. `store.set_current_adapter(path)` atomically flips the pointer via
+   `os.replace` on a tmp file (already implemented in Sprint 04).
+
+The `commit_version()` helper bundles steps 1 + 3 + 4 around a
+caller-supplied writer function, so the "happy path" is one call. If
+the writer raises, the pending directory is *not* made current — it's
+left in place so the caller can inspect / clean up / retry.
+"""
+
+from __future__ import annotations
+
+import os
+from collections.abc import Callable
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from dlm.store.paths import StorePath
+
+# Regex-safe prefix shared with `StorePath.adapter_version`.
+_VERSION_PREFIX = "v"
+
+
+def allocate_next_version(store: StorePath) -> Path:
+    """Return the next empty `adapter/versions/vNNNN/` path.
+
+    Creates the directory (and any missing parents). `StorePath.adapter_version`
+    does *not* create; we do it here so callers can start writing
+    immediately.
+    """
+    existing = _existing_versions(store)
+    next_n = (max(existing) if existing else 0) + 1
+    version_dir = store.adapter_version(next_n)
+    version_dir.mkdir(parents=True, exist_ok=False)
+    return version_dir
+
+
+def commit_version(
+    store: StorePath,
+    writer: Callable[[Path], None],
+) -> Path:
+    """Allocate → populate → fsync → flip pointer.
+
+    Returns the committed version directory. On writer exception:
+    - the pending directory is left on disk (not cleaned up, so the
+      caller can diagnose)
+    - the current pointer is NOT updated
+    - the exception propagates
+    """
+    pending = allocate_next_version(store)
+    try:
+        writer(pending)
+    except BaseException:
+        # Leave `pending` on disk; the next allocate_next_version call
+        # skips over it by bumping NNNN. Cleanup is a caller concern.
+        raise
+
+    fsync_dir(pending)
+    store.set_current_adapter(pending)
+    return pending
+
+
+def fsync_dir(path: Path) -> None:
+    """Flush the directory entry for `path` to disk.
+
+    After writing the adapter files we need to ensure the directory
+    metadata (new file entries) survives a power loss. `os.fsync` on
+    the directory fd is the POSIX idiom. Windows doesn't allow opening
+    a directory handle for fsync; on Windows we no-op (the underlying
+    filesystem typically metadata-journals).
+    """
+    if os.name == "nt":  # pragma: no cover — macOS/Linux covered
+        return
+    fd = os.open(str(path), os.O_RDONLY)
+    try:
+        os.fsync(fd)
+    finally:
+        os.close(fd)
+
+
+def list_pending_versions(store: StorePath) -> list[Path]:
+    """Return version dirs that exist on disk but aren't the current pointer.
+
+    Used by the trainer's startup routine to detect crash-before-flip
+    remnants: if the pending dir has a complete adapter + training_state
+    + matching sha256, the user could in principle resume from it by
+    manually flipping the pointer. We surface them rather than
+    auto-deleting.
+    """
+    existing = _existing_versions(store)
+    current = store.resolve_current_adapter()
+    current_n = _parse_version_number(current) if current is not None else None
+    return [
+        store.adapter_version(n)
+        for n in sorted(existing)
+        if n != current_n
+    ]
+
+
+def _existing_versions(store: StorePath) -> list[int]:
+    base = store.adapter_versions
+    if not base.is_dir():
+        return []
+    out: list[int] = []
+    for entry in base.iterdir():
+        if not entry.is_dir():
+            continue
+        n = _parse_version_dirname(entry.name)
+        if n is not None:
+            out.append(n)
+    return out
+
+
+def _parse_version_dirname(name: str) -> int | None:
+    if not name.startswith(_VERSION_PREFIX):
+        return None
+    try:
+        return int(name[len(_VERSION_PREFIX) :])
+    except ValueError:
+        return None
+
+
+def _parse_version_number(path: Path) -> int | None:
+    return _parse_version_dirname(path.name)
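
For orientation, a minimal usage sketch of the happy path (not part of the commit): it assumes the `StorePath` interface described in the docstring, and `adapter` / `save_state` are hypothetical stand-ins for the PEFT adapter object and the `state_sidecar` helper.

# Hypothetical usage sketch: `adapter` and `save_state` are stand-ins,
# not names introduced by this commit.
from pathlib import Path

from dlm.store.paths import StorePath
from dlm.train.checkpoint_commit import commit_version


def save_checkpoint(store: StorePath, adapter, save_state) -> Path:
    def writer(pending: Path) -> None:
        # Lifecycle step 2: populate the freshly allocated version dir.
        adapter.save_pretrained(pending)  # adapter config + weights
        save_state(pending)               # training_state.pt + sha256

    # Steps 1, 3 and 4 happen inside commit_version: allocate vNNNN,
    # fsync the dir, then atomically flip current.txt. If `writer`
    # raises, the pointer never moves and `pending` stays for triage.
    return commit_version(store, writer)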
src/dlm/train/disk_preflight.py (added)
@@ -0,0 +1,84 @@
+"""Pre-train disk-space check (audit F12).
+
+A training run that fails halfway through because the disk filled up
+leaves the store in a confusing state (half-written checkpoint, partial
+log). Catch it up-front: estimate the bytes we're about to write, add
+a safety margin, and refuse to start if the filesystem doesn't have
+headroom.
+
+The estimate is deliberately pessimistic — LoRA adapters are usually
+<100 MB, but the checkpoint also has to hold the torch-serialized
+optimizer state (which is ~2× the adapter for AdamW + scaler), logs,
+and any cached evaluation artifacts. Replay-corpus growth (one snapshot
+per new section, zstd-compressed) is absorbed by the safety margin.
+
+This module doesn't know about the heavy HF stack; it works from
+`BaseModelSpec` + `TrainingPlan` + simple arithmetic. Heavy math lives
+in `dlm.hardware.memory`; here we only need a byte estimate.
+"""
+
+from __future__ import annotations
+
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from dlm.train.errors import DiskSpaceError
+
+if TYPE_CHECKING:
+    from dlm.base_models import BaseModelSpec
+    from dlm.hardware.plan import TrainingPlan
+
+# Floor estimates (bytes). Generous by design; a run that dies midway on
+# a full disk is a much worse UX than one that refuses to start early.
+_LOG_RESERVE = 10 * 1024 * 1024  # 10 MB per run for JSONL logs
+_OPTIMIZER_MULTIPLIER = 2.5  # AdamW + scaler + schedule state vs adapter bytes
+_ADAPTER_FLOOR = 50 * 1024 * 1024  # 50 MB minimum adapter size (conservative)
+
+
+def estimate_checkpoint_bytes(spec: BaseModelSpec, plan: TrainingPlan) -> int:
+    """Rough byte estimate for one full checkpoint commit.
+
+    Components:
+    - Adapter weights (LoRA rank × shapes; modeled as a fraction of base)
+    - Optimizer state (AdamW keeps m + v per trainable param)
+    - Scaler + scheduler state
+    - Log reserve
+    """
+    # LoRA adds roughly `r × (in + out)` params per target module. For a
+    # canonical rank-16 adapter on a 1.5B model this works out to ~50 MB
+    # in fp16 / ~100 MB in fp32. The estimate treats the adapter as 1%
+    # of the base size, clamped to the floor.
+    base_bytes = int(spec.size_gb_fp16 * (1024**3))
+    adapter_bytes = max(base_bytes // 100, _ADAPTER_FLOOR)
+    optimizer_bytes = int(adapter_bytes * _OPTIMIZER_MULTIPLIER)
+
+    # Gradient checkpointing trades time for memory at runtime but
+    # doesn't change checkpoint size, so it's not in the formula here.
+    _ = plan  # suppress unused warning; kept as a hook for future plan-driven heuristics
+
+    return adapter_bytes + optimizer_bytes + _LOG_RESERVE
+
+
+def preflight_disk(
+    store_root: Path,
+    spec: BaseModelSpec,
+    plan: TrainingPlan,
+    *,
+    safety: float = 1.5,
+) -> None:
+    """Raise `DiskSpaceError` if the store FS can't fit a checkpoint + margin.
+
+    `safety` defaults to 1.5× — the trainer can get unlucky with
+    intermediate buffers, and a hard-fail at step 9/10 of a multi-hour
+    run is painful.
+    """
+    if safety <= 0:
+        raise ValueError(f"safety must be > 0, got {safety!r}")
+
+    estimate = estimate_checkpoint_bytes(spec, plan)
+    required = int(estimate * safety)
+
+    usage = shutil.disk_usage(store_root)
+    if usage.free < required:
+        raise DiskSpaceError(required_bytes=required, free_bytes=usage.free)
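
As a sanity check on the arithmetic (hypothetical numbers, not from this commit): for a spec with `size_gb_fp16 = 3.0`, roughly a 1.5B-parameter model in fp16, the estimate unrolls as:

# Worked example with hypothetical spec values.
base_bytes      = int(3.0 * 1024**3)                        # 3,221,225,472
adapter_bytes   = max(base_bytes // 100, 50 * 1024 * 1024)  # floor wins: 50 MiB
optimizer_bytes = int(adapter_bytes * 2.5)                  # 125 MiB
estimate        = adapter_bytes + optimizer_bytes + 10 * 1024 * 1024  # ~185 MiB
required        = int(estimate * 1.5)                       # ~278 MiB must be free

Note that the 50 MB floor dominates until the fp16 base crosses roughly 5 GB, at which point the 1%-of-base term takes over.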