@@ -44,26 +44,39 @@ if TYPE_CHECKING: |
| 44 | 44 | def _resolve_layer(model: Any, layer_index: int) -> Any: |
| 45 | 45 | """Locate the residual-stream module at `layer_index`. |
| 46 | 46 | |
| 47 | | - HF decoder-only models all expose `model.model.layers[i]` (Llama, |
| 48 | | - Qwen, SmolLM, Phi — the canonical path). We try that first, then |
| 49 | | - fall back to `model.layers[i]` for models that expose layers at |
| 50 | | - the top level. `layer_index` can be negative (`-1` is the last |
| 51 | | - layer, matching list-indexing semantics). |
| 52 | | - |
| 53 | | - Raises `ControlApplyError` if neither attribute exists or the |
| 54 | | - index is out of bounds. |
| 47 | + HF decoder-only models expose `model.model.layers[i]` (Llama, |
| 48 | + Qwen, SmolLM, Phi — the canonical path). PEFT wraps that under a |
| 49 | + `base_model.model.model.layers[i]` chain (`PeftModel.base_model` |
| 50 | + is a `LoraModel` whose `.model` is the HF model). Rather than |
| 51 | + hard-code two shapes, we walk down repeated `base_model` / `model` |
| 52 | + hops — the first node that exposes a `layers` attribute wins. |
| 53 | + `layer_index` can be negative (`-1` is the last layer). |
| 54 | + |
| 55 | + Raises `ControlApplyError` when the walker can't find a `layers` |
| 56 | + attribute anywhere down the chain or when the index is out of |
| 57 | + bounds. |
| 55 | 58 | """ |
| 56 | 59 | layers = None |
| 57 | | - inner = getattr(model, "model", None) |
| 58 | | - if inner is not None: |
| 59 | | - layers = getattr(inner, "layers", None) |
| 60 | | - if layers is None: |
| 61 | | - layers = getattr(model, "layers", None) |
| 60 | + node: Any = model |
| 61 | + # Cap the walk so a pathological graph can't spin forever; real |
| 62 | + # wrappers are at most 2-3 deep (PEFT adds 2, a rare Accelerate |
| 63 | + # wrapper adds 1). |
| 64 | + for _ in range(6): |
| 65 | + layers = getattr(node, "layers", None) |
| 66 | + if layers is not None: |
| 67 | + break |
| 68 | + next_node = getattr(node, "base_model", None) or getattr(node, "model", None) |
| 69 | + if next_node is None or next_node is node: |
| 70 | + break |
| 71 | + node = next_node |
| 72 | + |
| 62 | 73 | if layers is None: |
| 63 | 74 | raise ControlApplyError( |
| 64 | | - "model has neither `model.layers` nor `layers` — don't know " |
| 65 | | - "where to attach the forward hook. Pass a HF decoder-only " |
| 66 | | - "model, or open an issue with the model class for wiring." |
| 75 | + "model exposes no `layers` attribute along the " |
| 76 | + "`base_model` / `model` chain — don't know where to " |
| 77 | + "attach the forward hook. Pass a HF decoder-only model " |
| 78 | + "(or a PEFT-wrapped one), or file an issue with the model " |
| 79 | + "class for wiring." |
| 67 | 80 | ) |
| 68 | 81 | try: |
| 69 | 82 | return layers[layer_index] |