tenseleyflow/sway / 7f806df

Browse files

sway(backends): HF as_scaled_adapter mutates PEFT LoraLayer scaling dict

Authored by espadonne
SHA
7f806df19062ed82ca820ac3f344a20213aa4c4b
Parents
bbaff71
Tree
f1d12ad

1 changed file

Status | File | + | -
M src/dlm_sway/backends/hf.py 32 0
src/dlm_sway/backends/hf.py — modified
@@ -269,6 +269,38 @@ class HuggingFaceDifferentialBackend:
269269
         finally:
270270
             self._exit()
271271
 
272
+    @contextmanager
273
+    def as_scaled_adapter(self, lam: float) -> Iterator[_HFView]:
274
+        """Temporarily multiply every LoRA layer's scaling factor by ``lam``.
275
+
276
+        Works by walking the PEFT module tree and mutating each
277
+        ``LoraLayer.scaling[adapter_name]`` in place. The original
278
+        scalings are restored when the context exits — or when an
279
+        exception propagates, to keep the model in a sane state.
280
+        """
281
+        self._enter(f"scaled({lam})")
282
+        saved: list[tuple[object, str, float]] = []
283
+        try:
284
+            import peft  # noqa: PLC0415 — already a hard dep of this backend
285
+
286
+            lora_cls = getattr(peft.tuners.lora, "LoraLayer", None)
287
+            if lora_cls is None:
288
+                raise RuntimeError("peft.tuners.lora.LoraLayer not found; check peft>=0.13 pin")
289
+            for module in self._peft_model.modules():
290
+                if not isinstance(module, lora_cls):
291
+                    continue
292
+                scaling = getattr(module, "scaling", None)
293
+                if not isinstance(scaling, dict):
294
+                    continue
295
+                for key, original in scaling.items():
296
+                    saved.append((module, key, float(original)))
297
+                    scaling[key] = float(original) * lam
298
+            yield self._make_view(f"scaled_{lam:.2f}")
299
+        finally:
300
+            for module, key, original in saved:
301
+                module.scaling[key] = original  # type: ignore[attr-defined]
302
+            self._exit()
303
+
272304
     def close(self) -> None:
273305
         """Release GPU memory. Safe to call more than once."""
274306
         if getattr(self, "_peft_model", None) is not None: