tenseleyflow/sway / 93c3098

sway(backends): as_null_adapter with in-place lora weight randomization

Authored by espadonne
SHA: 93c3098bfa4d7d0ea9c768f6c77f421b2c6c3943
Parents: fadb0ae
Tree: 4d5b21e

2 changed files

Status  File                            +    -
M       src/dlm_sway/backends/dummy.py  39   0
M       src/dlm_sway/backends/hf.py     37   0
src/dlm_sway/backends/dummy.py (modified)
@@ -117,6 +117,37 @@ class _DummyView:
        )


+class _NullView(_DummyView):
+    """A dummy view that perturbs the base distribution with seeded noise.
+
+    Used by :meth:`DummyDifferentialBackend.as_null_adapter`. The
+    perturbation is small (matches an ``init_scale=0.02`` adapter) so
+    the null-vs-base divergence stays well below real-adapter territory
+    in probe tests.
+    """
+
+    def __init__(self, base_responses: DummyResponses, seed: int, init_scale: float) -> None:
+        super().__init__("base", base_responses)
+        self._seed = seed
+        self._init_scale = init_scale
+
+    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
+        base_dist = super().next_token_dist(prompt, top_k=top_k)
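+        # hash() is salted per process for str, so these draws repeat within
+        # one interpreter run but differ across runs unless PYTHONHASHSEED
+        # is pinned.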
+        rng = np.random.default_rng(self._seed + hash(prompt) % 1_000_003)
+        noise = rng.normal(0.0, self._init_scale, size=base_dist.logprobs.shape).astype(np.float32)
+        new_lp = base_dist.logprobs + noise
+        # Re-normalize (within the top-k slice) so a valid distribution comes back.
+        max_lp = new_lp.max()
+        new_probs = np.exp(new_lp - max_lp)
+        new_probs /= new_probs.sum()
+        return TokenDist(
+            token_ids=base_dist.token_ids,
+            logprobs=np.log(new_probs).astype(np.float32),
+            vocab_size=base_dist.vocab_size,
+            tail_logprob=base_dist.tail_logprob,
+        )
+
+
class _InterpolatedView(_DummyView):
    """A dummy view where logits/dists are a lam-blend of base and ft.

@@ -206,6 +237,14 @@ class DummyDifferentialBackend:
        finally:
            self._exit()

+    @contextmanager
+    def as_null_adapter(self, seed: int, *, init_scale: float = 0.02) -> Iterator[_DummyView]:
+        self._enter(f"null({seed})")
+        try:
+            yield _NullView(self._base_r, seed=seed, init_scale=init_scale)
+        finally:
+            self._exit()
+
    def _enter(self, mode: str) -> None:
        if self._active is not None:
            raise RuntimeError(
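
For orientation, a minimal usage sketch of the dummy-side context manager. Nothing below is part of the commit: the backend instance is assumed to already exist, and check_null_view is a hypothetical test helper.

import numpy as np

def check_null_view(backend, prompt: str) -> None:
    # Sketch: a null view should come back as a valid distribution that
    # stays close to base (its noise matches an init_scale=0.02 adapter).
    with backend.as_null_adapter(seed=7, init_scale=0.02) as view:
        dist = view.next_token_dist(prompt, top_k=256)
    # next_token_dist renormalizes within the top-k slice, so the returned
    # logprobs should exponentiate to (approximately) one over that slice.
    assert np.isclose(np.exp(dist.logprobs).sum(), 1.0, atol=1e-4)
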
src/dlm_sway/backends/hf.py (modified)
@@ -301,6 +301,43 @@ class HuggingFaceDifferentialBackend:
                module.scaling[key] = original  # type: ignore[attr-defined]
            self._exit()

+    @contextmanager
+    def as_null_adapter(self, seed: int, *, init_scale: float = 0.02) -> Iterator[_HFView]:
+        """Temporarily replace every LoRA ``A``/``B`` tensor with random noise.
+
+        Same rank, alpha, and target modules as the real adapter — only
+        the weights differ. This is the denominator in every z-score
+        path: "how much signal does structural noise produce?"
+
+        Implementation walks the PEFT module tree for ``lora_A``/``lora_B``
+        parameters, saves a clone of each current value, overwrites in
+        place with a zero-mean Gaussian at ``init_scale``, and restores
+        on exit (including on exception).
+        """
+        import torch
+
+        self._enter(f"null({seed})")
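+        # Noise is drawn on a CPU generator so the draw is reproducible
+        # regardless of the model's device; each tensor is cast and moved
+        # to the parameter's dtype/device below.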
+        gen = torch.Generator(device="cpu").manual_seed(int(seed))
+        saved: list[tuple[torch.nn.Parameter, torch.Tensor]] = []
+        try:
+            for pname, param in self._peft_model.named_parameters():
+                if not any(key in pname for key in ("lora_A", "lora_B")):
+                    continue
+                saved.append((param, param.detach().clone()))
+                with torch.no_grad():
+                    noise = torch.randn(
+                        *param.shape,
+                        generator=gen,
+                        dtype=torch.float32,
+                    ).to(dtype=param.dtype, device=param.device)
+                    param.copy_(noise * init_scale)
+            yield self._make_view(f"null_{seed}")
+        finally:
+            with torch.no_grad():
+                for param, original in saved:
+                    param.copy_(original)
+            self._exit()
+
    def close(self) -> None:
        """Release GPU memory. Safe to call more than once."""
        if getattr(self, "_peft_model", None) is not None:
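
The hf.py docstring frames the null adapter as the denominator of a z-score ("how much signal does structural noise produce?"). A sketch of that calibration loop follows; score_fn and adapter_score are hypothetical stand-ins, and only as_null_adapter comes from this commit.

import numpy as np

def null_calibrated_z(backend, score_fn, adapter_score: float, n_seeds: int = 8) -> float:
    # score_fn(view) -> float is a hypothetical probe metric. Evaluate it
    # under several randomized null adapters to estimate how much signal
    # structural noise alone produces, then z-score the real adapter's score.
    null_scores = []
    for seed in range(n_seeds):
        with backend.as_null_adapter(seed) as view:
            null_scores.append(score_fn(view))
    mu = float(np.mean(null_scores))
    sigma = float(np.std(null_scores)) or 1e-9  # guard against zero spread
    return (adapter_score - mu) / sigma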