sway(backends): as_null_adapter with in-place LoRA weight randomization
- SHA: 93c3098bfa4d7d0ea9c768f6c77f421b2c6c3943
- Parents: fadb0ae
- Tree: 4d5b21e
| Status | File | + | - |
|---|---|---|---|
| M | src/dlm_sway/backends/dummy.py | 39 | 0 |
| M | src/dlm_sway/backends/hf.py | 37 | 0 |
**src/dlm_sway/backends/dummy.py** (modified)

```diff
@@ -117,6 +117,37 @@ class _DummyView:
         )
 
 
+class _NullView(_DummyView):
+    """A dummy view that perturbs the base distribution with seeded noise.
+
+    Used by :meth:`DummyDifferentialBackend.as_null_adapter`. The
+    perturbation is small (matches an ``init_scale=0.02`` adapter), so
+    the null-vs-base divergence stays well below real-adapter territory
+    in probe tests.
+    """
+
+    def __init__(self, base_responses: DummyResponses, seed: int, init_scale: float) -> None:
+        super().__init__("base", base_responses)
+        self._seed = seed
+        self._init_scale = init_scale
+
+    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
+        base_dist = super().next_token_dist(prompt, top_k=top_k)
+        rng = np.random.default_rng(self._seed + hash(prompt) % 1_000_003)
+        noise = rng.normal(0.0, self._init_scale, size=base_dist.logprobs.shape).astype(np.float32)
+        new_lp = base_dist.logprobs + noise
+        # Re-normalize (within the top-k slice) so a valid distribution comes back.
+        max_lp = new_lp.max()
+        new_probs = np.exp(new_lp - max_lp)
+        new_probs /= new_probs.sum()
+        return TokenDist(
+            token_ids=base_dist.token_ids,
+            logprobs=np.log(new_probs).astype(np.float32),
+            vocab_size=base_dist.vocab_size,
+            tail_logprob=base_dist.tail_logprob,
+        )
+
+
 class _InterpolatedView(_DummyView):
     """A dummy view where logits/dists are a lam-blend of base and ft.
 
@@ -206,6 +237,14 @@ class DummyDifferentialBackend:
         finally:
             self._exit()
 
+    @contextmanager
+    def as_null_adapter(self, seed: int, *, init_scale: float = 0.02) -> Iterator[_DummyView]:
+        self._enter(f"null({seed})")
+        try:
+            yield _NullView(self._base_r, seed=seed, init_scale=init_scale)
+        finally:
+            self._exit()
+
     def _enter(self, mode: str) -> None:
         if self._active is not None:
             raise RuntimeError(
```
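For readers skimming the diff, the core of `_NullView.next_token_dist` is just a perturb-and-renormalize over the top-k logprob slice. Here is a minimal standalone NumPy sketch of that step; the three probabilities are invented for illustration:

```python
import numpy as np

# Standalone demo of the _NullView perturbation: add small seeded Gaussian
# noise to a top-k logprob slice, then renormalize within the slice.
seed, init_scale = 0, 0.02
logprobs = np.log(np.array([0.6, 0.3, 0.1], dtype=np.float32))

rng = np.random.default_rng(seed)
noisy = logprobs + rng.normal(0.0, init_scale, size=logprobs.shape).astype(np.float32)

# Subtracting the max before exp is the usual numerical-stability trick;
# it cancels in the division, so this is a softmax over the noisy logprobs.
probs = np.exp(noisy - noisy.max())
probs /= probs.sum()

print(probs, probs.sum())  # a small perturbation of [0.6, 0.3, 0.1]; sums to 1.0
```

Because the noise scale matches a freshly initialized `init_scale=0.02` adapter, the resulting distribution stays close to the base one, which is exactly the property the probe tests rely on.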
**src/dlm_sway/backends/hf.py** (modified)

```diff
@@ -301,6 +301,43 @@ class HuggingFaceDifferentialBackend:
                 module.scaling[key] = original  # type: ignore[attr-defined]
             self._exit()
 
+    @contextmanager
+    def as_null_adapter(self, seed: int, *, init_scale: float = 0.02) -> Iterator[_HFView]:
+        """Temporarily replace every LoRA ``A``/``B`` tensor with random noise.
+
+        Same rank, alpha, and target modules as the real adapter; only
+        the weights differ. This is the denominator in every z-score
+        path: "how much signal does structural noise produce?"
+
+        The implementation walks the PEFT module tree for ``lora_A``/``lora_B``
+        parameters, saves a clone of each current value, overwrites in
+        place with a zero-mean Gaussian at ``init_scale``, and restores
+        on exit (including on exception).
+        """
+        import torch
+
+        self._enter(f"null({seed})")
+        gen = torch.Generator(device="cpu").manual_seed(int(seed))
+        saved: list[tuple[torch.nn.Parameter, torch.Tensor]] = []
+        try:
+            for pname, param in self._peft_model.named_parameters():
+                if not any(key in pname for key in ("lora_A", "lora_B")):
+                    continue
+                saved.append((param, param.detach().clone()))
+                with torch.no_grad():
+                    noise = torch.randn(
+                        *param.shape,
+                        generator=gen,
+                        dtype=torch.float32,
+                    ).to(dtype=param.dtype, device=param.device)
+                    param.copy_(noise * init_scale)
+            yield self._make_view(f"null_{seed}")
+        finally:
+            with torch.no_grad():
+                for param, original in saved:
+                    param.copy_(original)
+            self._exit()
+
     def close(self) -> None:
         """Release GPU memory. Safe to call more than once."""
         if getattr(self, "_peft_model", None) is not None:
```
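The hf.py side is a clone/overwrite/restore pattern around the PEFT parameters. Below is a minimal, self-contained sketch of the same pattern on a toy `torch.nn.Linear`; the `randomize_params` helper and its `match` filter are hypothetical names introduced here for illustration, not part of this commit:

```python
from contextlib import contextmanager

import torch


@contextmanager
def randomize_params(module: torch.nn.Module, match: tuple[str, ...], seed: int, scale: float = 0.02):
    """Toy version of the clone/overwrite/restore pattern used by as_null_adapter."""
    gen = torch.Generator(device="cpu").manual_seed(seed)
    saved = []
    try:
        for name, param in module.named_parameters():
            if not any(key in name for key in match):
                continue
            saved.append((param, param.detach().clone()))  # clone BEFORE overwriting
            with torch.no_grad():
                noise = torch.randn(*param.shape, generator=gen, dtype=torch.float32)
                param.copy_(noise.to(param.dtype) * scale)
        yield module
    finally:
        # Restore runs even if the body raises, so the module is bit-identical afterwards.
        with torch.no_grad():
            for param, original in saved:
                param.copy_(original)


layer = torch.nn.Linear(4, 4)
before = layer.weight.detach().clone()
with randomize_params(layer, match=("weight",), seed=0):
    assert not torch.equal(layer.weight, before)  # weights are noise in here
assert torch.equal(layer.weight, before)  # and restored on exit
```

Cloning before the overwrite and restoring inside `finally` is what backs the docstring's promise of restoration "on exit (including on exception)".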