tenseleyflow/documentlanguagemodel / f6b3599

Browse files

fix(control): resolve layers through PEFT base_model wrapper chain

Authored by espadonne
SHA
f6b359982e04a7b7703be8a0a5b799f2a66de657
Parents
ef3d707
Tree
0364da1

1 changed file

StatusFile+-
M src/dlm/control/apply.py 29 16
src/dlm/control/apply.pymodified
@@ -44,26 +44,39 @@ if TYPE_CHECKING:
4444
 def _resolve_layer(model: Any, layer_index: int) -> Any:
4545
     """Locate the residual-stream module at `layer_index`.
4646
 
47
-    HF decoder-only models all expose `model.model.layers[i]` (Llama,
48
-    Qwen, SmolLM, Phi — the canonical path). We try that first, then
49
-    fall back to `model.layers[i]` for models that expose layers at
50
-    the top level. `layer_index` can be negative (`-1` is the last
51
-    layer, matching list-indexing semantics).
52
-
53
-    Raises `ControlApplyError` if neither attribute exists or the
54
-    index is out of bounds.
47
+    HF decoder-only models expose `model.model.layers[i]` (Llama,
48
+    Qwen, SmolLM, Phi — the canonical path). PEFT wraps that under a
49
+    `base_model.model.model.layers[i]` chain (`PeftModel.base_model`
50
+    is a `LoraModel` whose `.model` is the HF model). Rather than
51
+    hard-code two shapes, we walk down repeated `base_model` / `model`
52
+    hops — the first node that exposes a `layers` attribute wins.
53
+    `layer_index` can be negative (`-1` is the last layer).
54
+
55
+    Raises `ControlApplyError` when the walker can't find a `layers`
56
+    attribute anywhere down the chain or when the index is out of
57
+    bounds.
5558
     """
5659
     layers = None
57
-    inner = getattr(model, "model", None)
58
-    if inner is not None:
59
-        layers = getattr(inner, "layers", None)
60
-    if layers is None:
61
-        layers = getattr(model, "layers", None)
60
+    node: Any = model
61
+    # Cap the walk so a pathological graph can't spin forever; real
62
+    # wrappers are at most 2-3 deep (PEFT adds 2, a rare Accelerate
63
+    # wrapper adds 1).
64
+    for _ in range(6):
65
+        layers = getattr(node, "layers", None)
66
+        if layers is not None:
67
+            break
68
+        next_node = getattr(node, "base_model", None) or getattr(node, "model", None)
69
+        if next_node is None or next_node is node:
70
+            break
71
+        node = next_node
72
+
6273
     if layers is None:
6374
         raise ControlApplyError(
64
-            "model has neither `model.layers` nor `layers` — don't know "
65
-            "where to attach the forward hook. Pass a HF decoder-only "
66
-            "model, or open an issue with the model class for wiring."
75
+            "model exposes no `layers` attribute along the "
76
+            "`base_model` / `model` chain — don't know where to "
77
+            "attach the forward hook. Pass a HF decoder-only model "
78
+            "(or a PEFT-wrapped one), or file an issue with the model "
79
+            "class for wiring."
6780
         )
6881
     try:
6982
         return layers[layer_index]