tenseleyflow/sway / beea653

probes/_param_id_mapping: layer-major param-id → layer-index grouping (S25 P2)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: beea653ee7178868dbb0fe3c4e3b1eaa537c31b5
Parents: f6e20a1
Tree: d79dee2

1 changed file

A  src/dlm_sway/probes/_param_id_mapping.py  (+176, −0)

src/dlm_sway/probes/_param_id_mapping.py (added)
@@ -0,0 +1,176 @@
+"""Map optimizer-state param-ids → adapter layer indices (S25, P2).
+
+The ``gradient_ghost`` probe needs to attribute per-param Adam stats
+(``exp_avg_sq`` magnitudes) back to layer indices for the per-layer
+reporting the sprint DoD requires.
+
+## Why not name-level attribution?
+
+The ``optimizer_state_dict.state`` dict in dlm's ``training_state.pt``
+keys parameters by integer id (0..N-1) in the order PEFT registered
+them with the optimizer. The full param-id → name mapping requires
+walking ``model.named_parameters()`` — but the probe's whole pitch
+is "no model load." So we punt on naming individual modules and
+group at the layer level instead.
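+
+For reference, the state we consume is shaped roughly like this
+(illustrative, not a literal dump)::
+
+    optimizer_state_dict["state"] == {
+        0: {"exp_avg": ..., "exp_avg_sq": ...},  # some LoRA factor,
+        1: {"exp_avg": ..., "exp_avg_sq": ...},  # but which module?
+        ...
+    }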
+
+## Why layer-grouping is safe
+
+PEFT registers trainable LoRA params in **layer-major order**: for
+each transformer block, all its LoRA factors first, then the next
+block. Verified against dlm-trained SmolLM2-135M:
+
+- 30 transformer layers
+- 4 target modules per layer (q_proj, k_proj, v_proj, o_proj)
+- 2 LoRA factors per module (A, B)
+- Total: 30 × 4 × 2 = 240 params, in layer-major order
+
+Even when the probe doesn't know *which* module within a layer a
+given param belongs to, grouping by layer index gives meaningful
+per-layer health reporting — and matches how a user thinks about
+"my adapter's layer 5 isn't learning."
+"""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from pathlib import Path
+
+from dlm_sway.core.errors import SwayError
+
+
+class ParamMappingError(SwayError):
+    """Raised when ``adapter_model.safetensors`` can't be read or
+    the param layout doesn't group cleanly by layer."""
+
+
+@dataclass(frozen=True, slots=True)
+class LayerGrouping:
+    """Result of grouping per-param ids by transformer-layer index.
+
+    Attributes
+    ----------
+    layer_indices:
+        Sorted, deduped tuple of transformer-layer indices the
+        adapter touches (e.g. ``(0, 1, 2, ..., 29)`` for SmolLM2).
+    params_per_layer:
+        How many trainable params each layer has (constant across
+        layers — PEFT applies the same target_modules everywhere).
+    layer_of:
+        Mapping from param-id to its transformer layer index.
+        Param-ids outside the layered param space (e.g. embedding
+        overrides via ``modules_to_save``) are absent; look them
+        up with ``.get()`` to receive ``None``.
+    """
+
+    layer_indices: tuple[int, ...]
+    params_per_layer: int
+    layer_of: dict[int, int]
+
+    @property
+    def num_layers(self) -> int:
+        return len(self.layer_indices)
+
+
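+# Matches the ``.layers.<n>.`` segment in keys such as (illustrative)
+# ``...layers.5.self_attn.q_proj.lora_A.weight``, capturing ``5``.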
+_LAYER_INDEX_RE = re.compile(r"\.layers\.(\d+)\.")
+
+
+def map_param_ids_to_layers(adapter_dir: Path, num_params: int) -> LayerGrouping:
+    """Group ``optimizer_state_dict.state`` param-ids by layer index.
+
+    Parameters
+    ----------
+    adapter_dir:
+        Adapter directory containing ``adapter_model.safetensors``.
+        We don't load the safetensors weights themselves — just
+        introspect the keys to count params per layer.
+    num_params:
+        Number of param-ids in the optimizer state. Used as a
+        cross-check: the safetensors key count should match.
+
+    Returns
+    -------
+    LayerGrouping
+        Per-layer grouping ready for the probe to attribute stats.
+
+    Raises
+    ------
+    ParamMappingError
+        ``adapter_model.safetensors`` is missing, the per-layer key
+        count is heterogeneous (PEFT's per-layer module set wasn't
+        uniform), or the optimizer state holds fewer params than the
+        layered keys imply.
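+
+    Examples
+    --------
+    A hypothetical call against the SmolLM2 adapter described in the
+    module docstring (path illustrative)::
+
+        grouping = map_param_ids_to_layers(Path("runs/demo/adapter"), 240)
+        grouping.layer_of.get(13)   # -> 1 (8 params/layer, so id 13 is layer 1)
+        grouping.layer_of.get(240)  # -> None (outside the layered range)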
+    """
+    safetensors_path = adapter_dir / "adapter_model.safetensors"
+    if not safetensors_path.exists():
+        raise ParamMappingError(
+            f"adapter_model.safetensors missing at {safetensors_path} — "
+            "can't recover layer indices for per-layer reporting"
+        )
+
+    try:
+        from safetensors import safe_open  # noqa: PLC0415 — lazy
+    except ImportError as exc:
+        raise ParamMappingError(
+            "safetensors not installed — gradient_ghost needs it for "
+            "per-layer attribution. Install with: pip install 'dlm-sway[hf]'"
+        ) from exc
+
+    # Use ``safe_open`` so we never materialize the tensors — we only
+    # need the key list. Saves the 7+ MB read on a typical adapter.
+    with safe_open(str(safetensors_path), framework="numpy", device="cpu") as fh:
+        keys = list(fh.keys())
+    if not keys:
+        raise ParamMappingError(f"{safetensors_path}: no keys found")
+
+    # Extract layer index from each key. Keys without a layer index
+    # (modules_to_save full-tensor overrides) are skipped and so
+    # excluded from per-layer attribution.
+    keys_by_layer: dict[int, int] = {}
+    for k in keys:
+        match = _LAYER_INDEX_RE.search(k)
+        if match is None:
+            continue
+        idx = int(match.group(1))
+        keys_by_layer[idx] = keys_by_layer.get(idx, 0) + 1
+
+    if not keys_by_layer:
+        raise ParamMappingError(
+            f"{safetensors_path}: no keys carry a layer index — adapter "
+            "may target only embeddings / lm_head"
+        )
+
+    layer_indices = tuple(sorted(keys_by_layer))
+    counts = set(keys_by_layer.values())
+    if len(counts) != 1:
+        # Heterogeneous per-layer param counts. The simple "K per
+        # layer in order" mapping breaks; refuse rather than guess.
+        raise ParamMappingError(
+            f"{safetensors_path}: per-layer param count is heterogeneous "
+            f"({sorted(counts)}) — gradient_ghost "
+            "can't safely attribute params to layers in this configuration"
+        )
+    params_per_layer = next(iter(counts))
153
+
154
+    # Cross-check vs optimizer-state count. They should match: each
155
+    # safetensors key has one optimizer-state entry. If we see fewer
156
+    # optimizer params than safetensors keys, the user pruned state;
157
+    # if we see more, modules_to_save params have their own optimizer
158
+    # entries (not handled here).
159
+    expected = params_per_layer * len(layer_indices)
160
+    if num_params < expected:
161
+        raise ParamMappingError(
162
+            f"optimizer state has {num_params} params but safetensors "
163
+            f"has {expected} layered weight keys — adapter / state mismatch"
164
+        )
165
+
166
+    layer_of: dict[int, int] = {}
167
+    for ordinal in range(expected):
168
+        layer_of[ordinal] = layer_indices[ordinal // params_per_layer]
169
+    # Param-ids beyond the layered range (modules_to_save) get None
170
+    # via dict.get(...) — the probe filters them out.
171
+
172
+    return LayerGrouping(
173
+        layer_indices=layer_indices,
174
+        params_per_layer=params_per_layer,
175
+        layer_of=layer_of,
176
+    )