@@ -131,3 +131,69 @@ class TestRegistry: |
| 131 | 131 | from dlm_sway.core.scoring import DifferentialBackend |
| 132 | 132 | |
| 133 | 133 | assert isinstance(backend, DifferentialBackend) |
| 134 | + |
| 135 | + |
| 136 | +class TestCustomProtocolStamp: |
| 137 | + """B20: ``__sway_protocols__`` records the optional protocols satisfied.""" |
| 138 | + |
| 139 | + def test_stamps_null_calibrated_when_satisfied(self) -> None: |
| 140 | + from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 141 | + |
| 142 | + class _AdapterBackend(DummyDifferentialBackend): |
| 143 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] |
| 144 | + super().__init__(base=DummyResponses(), ft=DummyResponses()) |
| 145 | + |
| 146 | + import sys |
| 147 | + import types |
| 148 | + |
| 149 | + mod = types.ModuleType("_sway_protostamp_mod") |
| 150 | + mod.AdapterBackend = _AdapterBackend # type: ignore[attr-defined] |
| 151 | + sys.modules["_sway_protostamp_mod"] = mod |
| 152 | + |
| 153 | + backend = build( |
| 154 | + ModelSpec( |
| 155 | + base="x", |
| 156 | + kind="custom", |
| 157 | + entry_point="_sway_protostamp_mod:AdapterBackend", |
| 158 | + adapter=Path("/tmp/a"), |
| 159 | + ) |
| 160 | + ) |
| 161 | + protocols = getattr(backend, "__sway_protocols__", ()) |
| 162 | + assert "DifferentialBackend" in protocols |
| 163 | + # DummyDifferentialBackend satisfies both optional protocols. |
| 164 | + assert "NullCalibratedBackend" in protocols |
| 165 | + assert "ScalableDifferentialBackend" in protocols |
| 166 | + |
| 167 | + def test_minimal_diff_backend_only_records_diff(self) -> None: |
| 168 | + """A backend that only satisfies DifferentialBackend gets stamped accordingly.""" |
| 169 | + from contextlib import contextmanager |
| 170 | + |
| 171 | + class _MinimalBackend: |
| 172 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] |
| 173 | + del base_spec, adapter_path |
| 174 | + |
| 175 | + @contextmanager |
| 176 | + def as_base(self): # noqa: ANN201 |
| 177 | + yield self |
| 178 | + |
| 179 | + @contextmanager |
| 180 | + def as_finetuned(self): # noqa: ANN201 |
| 181 | + yield self |
| 182 | + |
| 183 | + import sys |
| 184 | + import types |
| 185 | + |
| 186 | + mod = types.ModuleType("_sway_minimal_mod") |
| 187 | + mod.Minimal = _MinimalBackend # type: ignore[attr-defined] |
| 188 | + sys.modules["_sway_minimal_mod"] = mod |
| 189 | + |
| 190 | + backend = build( |
| 191 | + ModelSpec( |
| 192 | + base="x", |
| 193 | + kind="custom", |
| 194 | + entry_point="_sway_minimal_mod:Minimal", |
| 195 | + adapter=Path("/tmp/a"), |
| 196 | + ) |
| 197 | + ) |
| 198 | + protocols = getattr(backend, "__sway_protocols__", ()) |
| 199 | + assert protocols == ("DifferentialBackend",) |