1 """Map optimizer-state param-ids → adapter layer indices (S25, P2).
2
3 The ``gradient_ghost`` probe needs to attribute per-param Adam stats
4 (``exp_avg_sq`` magnitudes) back to layer indices for the per-layer
5 reporting the sprint DoD requires.
6
7 ## Why not name-level attribution?
8
9 The ``optimizer_state_dict.state`` dict in dlm's ``training_state.pt``
10 keys parameters by integer id (0..N-1) in the order PEFT registered
11 them with the optimizer. The full param-id → name mapping requires
12 walking ``model.named_parameters()`` — but the probe's whole pitch
13 is "no model load." So we punt on naming individual modules and
14 group at the layer level instead.
15
16 ## Why layer-grouping is safe
17
18 PEFT registers trainable LoRA params in **layer-major order**: for
19 each transformer block, all its LoRA factors first, then the next
20 block. Verified against dlm-trained SmolLM2-135M:
21
22 - 30 transformer layers
23 - 4 target modules per layer (q_proj, k_proj, v_proj, o_proj)
24 - 2 LoRA factors per module (A, B)
25 - Total: 30 × 4 × 2 = 240 params, in layer-major order
26
27 Even when the probe doesn't know *which* module within a layer a
28 given param belongs to, grouping by layer index gives meaningful
29 per-layer health reporting — and matches how a user thinks about
30 "my adapter's layer 5 isn't learning."
31 """
32
33 from __future__ import annotations
34
35 import re
36 from dataclasses import dataclass
37 from pathlib import Path
38
39 from dlm_sway.core.errors import SwayError
40
41
42 class ParamMappingError(SwayError):
    """Raised when ``adapter_model.safetensors`` can't be read, its keys
    can't be grouped uniformly per layer, or the optimizer state doesn't
    cover the adapter's layered params."""


@dataclass(frozen=True, slots=True)
class LayerGrouping:
    """Result of grouping per-param ids by transformer-layer index.

    Attributes
    ----------
    layer_indices:
        Sorted, deduped tuple of transformer-layer indices the
        adapter touches (e.g. ``(0, 1, 2, ..., 29)`` for SmolLM2).
    params_per_layer:
        How many trainable params each layer has (constant across
        layers — PEFT applies the same target_modules everywhere).
    layer_of:
        Dict mapping each layered param-id to its transformer layer
        index. Param-ids outside the layered param space (e.g.
        embedding overrides via ``modules_to_save``) are absent, so
        ``layer_of.get(pid)`` returns ``None`` for them.
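
    Example (hypothetical two-layer adapter with two LoRA factors per
    layer)::

        grouping = LayerGrouping(
            layer_indices=(0, 1),
            params_per_layer=2,
            layer_of={0: 0, 1: 0, 2: 1, 3: 1},
        )
        grouping.num_layers          # 2
        grouping.layer_of.get(99)    # None (outside the layered range)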
    """

    layer_indices: tuple[int, ...]
    params_per_layer: int
    layer_of: dict[int, int]

    @property
    def num_layers(self) -> int:
        return len(self.layer_indices)


_LAYER_INDEX_RE = re.compile(r"\.layers\.(\d+)\.")
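# The key shape assumed here is the usual HF/PEFT layout, e.g.
# "base_model.model.model.layers.5.self_attn.q_proj.lora_A.weight"
# (illustrative); only the ".layers.<n>." segment is consulted.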


def map_param_ids_to_layers(adapter_dir: Path, num_params: int) -> LayerGrouping:
    """Group ``optimizer_state_dict.state`` param-ids by layer index.

    Parameters
    ----------
    adapter_dir:
        Adapter directory containing ``adapter_model.safetensors``.
        We don't load the safetensors weights themselves — just
        introspect the keys to count params per layer.
    num_params:
        Number of param-ids in the optimizer state. Used as a
        cross-check: the safetensors key count should match.

    Returns
    -------
    LayerGrouping
        Per-layer grouping ready for the probe to attribute stats.

    Raises
    ------
    ParamMappingError
        ``adapter_model.safetensors`` is missing or has no usable keys,
        the per-layer param count isn't uniform (unexpected adapter
        shape), or the optimizer state has fewer params than the
        adapter's layered keys.
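
    Examples
    --------
    A minimal usage sketch (paths and counts are hypothetical)::

        grouping = map_param_ids_to_layers(
            Path("runs/exp1/adapter"), num_params=240
        )
        grouping.num_layers          # e.g. 30
        grouping.layer_of.get(17)    # transformer layer of param-id 17
        grouping.layer_of.get(500)   # None -> not a layered LoRA param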
    """
    safetensors_path = adapter_dir / "adapter_model.safetensors"
    if not safetensors_path.exists():
        raise ParamMappingError(
            f"adapter_model.safetensors missing at {safetensors_path} — "
            "can't recover layer indices for per-layer reporting"
        )

    try:
        from safetensors import safe_open  # noqa: PLC0415 — lazy
    except ImportError as exc:
        raise ParamMappingError(
            "safetensors not installed — gradient_ghost needs it for "
            "per-layer attribution. Install with: pip install 'dlm-sway[hf]'"
        ) from exc

    # Use ``safe_open`` so we never materialize the tensors — we only
    # need the key list. Saves the 7+ MB read on a typical adapter.
    # safetensors ships no py.typed in some versions and a py.typed
    # marker in others — silence both states (untyped-call when the
    # marker is missing, unused-ignore when it's present).
    with safe_open(  # type: ignore[no-untyped-call,unused-ignore]
        str(safetensors_path), framework="numpy", device="cpu"
    ) as fh:
        keys = list(fh.keys())
    if not keys:
        raise ParamMappingError(f"{safetensors_path}: no keys found")

    # Extract layer index from each key. Keys without a layer index
    # (modules_to_save full-tensor overrides) are skipped and excluded
    # from per-layer attribution.
    keys_by_layer: dict[int, int] = {}
    for k in keys:
        match = _LAYER_INDEX_RE.search(k)
        if match is None:
            continue
        idx = int(match.group(1))
        keys_by_layer[idx] = keys_by_layer.get(idx, 0) + 1
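    # keys_by_layer now maps layer index -> key count; for the
    # (hypothetical) SmolLM2 shape in the module docstring that would
    # look like {0: 8, 1: 8, ..., 29: 8}.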

    if not keys_by_layer:
        raise ParamMappingError(
            f"{safetensors_path}: no keys carry a layer index — adapter "
            "may target only embeddings / lm_head"
        )

    layer_indices = tuple(sorted(keys_by_layer))
    counts = set(keys_by_layer.values())
    if len(counts) != 1:
        # Heterogeneous per-layer param counts. The simple "K per
        # layer in order" mapping breaks; refuse rather than guess.
        raise ParamMappingError(
            f"{safetensors_path}: per-layer param count is heterogeneous "
            f"({sorted(set(keys_by_layer.values()))}) — gradient_ghost "
            "can't safely attribute params to layers in this configuration"
        )
    params_per_layer = next(iter(counts))

    # Cross-check vs optimizer-state count. They should match: each
    # safetensors key has one optimizer-state entry. If we see fewer
    # optimizer params than safetensors keys, the user pruned state;
    # if we see more, modules_to_save params have their own optimizer
    # entries (not handled here).
    expected = params_per_layer * len(layer_indices)
    if num_params < expected:
        raise ParamMappingError(
            f"optimizer state has {num_params} params but safetensors "
            f"has {expected} layered weight keys — adapter / state mismatch"
        )

    layer_of: dict[int, int] = {}
    for ordinal in range(expected):
        layer_of[ordinal] = layer_indices[ordinal // params_per_layer]
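    # With a uniform adapter this walks layer-major: ordinals 0..K-1 map
    # to layer_indices[0], K..2K-1 to layer_indices[1], and so on, where
    # K = params_per_layer.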
    # Param-ids beyond the layered range (modules_to_save) get None
    # via dict.get(...) — the probe filters them out.

    return LayerGrouping(
        layer_indices=layer_indices,
        params_per_layer=params_per_layer,
        layer_of=layer_of,
    )