tenseleyflow/sway / 4cde232

Browse files

sway(core): ScalableDifferentialBackend protocol for adapter-scale access

Authored by espadonne
SHA
4cde23284d23c934fbfd04ba8b1ec62873b5badc
Parents
cc8e52f
Tree
28e9323

2 changed files

StatusFile+-
M src/dlm_sway/__init__.py 2 0
M src/dlm_sway/core/scoring.py 21 0
src/dlm_sway/__init__.pymodified
@@ -13,6 +13,7 @@ from dlm_sway.core.result import ProbeResult, SuiteResult, SwayScore, Verdict
1313
 from dlm_sway.core.scoring import (
1414
     DifferentialBackend,
1515
     RollingLogprob,
16
+    ScalableDifferentialBackend,
1617
     ScoringBackend,
1718
     TokenDist,
1819
 )
@@ -26,6 +27,7 @@ __all__ = [
2627
     "ProbeError",
2728
     "ProbeResult",
2829
     "RollingLogprob",
30
+    "ScalableDifferentialBackend",
2931
     "ScoringBackend",
3032
     "SpecValidationError",
3133
     "SuiteResult",
src/dlm_sway/core/scoring.pymodified
@@ -140,6 +140,27 @@ class DifferentialBackend(Protocol):
140140
     def as_finetuned(self) -> AbstractContextManager[_ScoringModel]: ...
141141
 
142142
 
143
+@runtime_checkable
144
+class ScalableDifferentialBackend(DifferentialBackend, Protocol):
145
+    """A differential backend that can also scale the LoRA additive term.
146
+
147
+    LoRA applies ``W + (alpha/r) · B @ A`` to a base weight matrix. This
148
+    protocol exposes a context manager that temporarily multiplies that
149
+    additive term by ``lam`` for everything inside the ``with`` block.
150
+
151
+    ``lam = 0.0`` is equivalent to :meth:`as_base`.
152
+    ``lam = 1.0`` is equivalent to :meth:`as_finetuned`.
153
+    ``lam = 1.25`` overshoots — useful for N2 AdapterAblation's
154
+    response-curve measurement.
155
+
156
+    Only the HF backend ships an implementation in v0.1. Probes that
157
+    need scaling check via ``isinstance(backend, ScalableDifferentialBackend)``
158
+    at runtime and SKIP gracefully when unavailable.
159
+    """
160
+
161
+    def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]: ...
162
+
163
+
143164
 # Helper Protocol for type-checking the yielded context object: it
144165
 # must satisfy both Model and ScoringBackend. mypy doesn't support
145166
 # intersection types, so we spell it out explicitly.