sway(backends): custom entry-point dispatch with protocol validation
- SHA
ac3f2ff40004e551199966391b1876b9e5f67e20- Parents
-
c666cbf - Tree
d64360e
ac3f2ff
ac3f2ff40004e551199966391b1876b9e5f67e20c666cbf
d64360e| Status | File | + | - |
|---|---|---|---|
| M |
src/dlm_sway/backends/__init__.py
|
53 | 5 |
| M |
tests/unit/test_backend_registry.py
|
82 | 3 |
src/dlm_sway/backends/__init__.pymodified@@ -55,13 +55,61 @@ def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> Differen | ||
| 55 | 55 | ) |
| 56 | 56 | |
| 57 | 57 | if base_spec.kind == "custom": |
| 58 | - raise BackendNotAvailableError( | |
| 59 | - "custom", | |
| 60 | - extra="hf", | |
| 61 | - hint="Custom backend entry-point dispatch shipping in a later milestone.", | |
| 62 | - ) | |
| 58 | + return _load_custom(base_spec, effective_adapter) | |
| 63 | 59 | |
| 64 | 60 | raise SpecValidationError(f"unknown backend kind: {base_spec.kind!r}") |
| 65 | 61 | |
| 66 | 62 | |
| 63 | +def _load_custom(base_spec: ModelSpec, adapter: Path | None) -> DifferentialBackend: | |
| 64 | + """Dispatch to a user-supplied backend via ``entry_point='pkg.mod:Name'``. | |
| 65 | + | |
| 66 | + The imported class is instantiated as ``Cls(base_spec=..., adapter_path=...)`` | |
| 67 | + — the same signature as :class:`dlm_sway.backends.hf.HuggingFaceDifferentialBackend` | |
| 68 | + so authors can model their implementation on the built-in. The | |
| 69 | + result is runtime-checked against :class:`DifferentialBackend` so | |
| 70 | + protocol violations fail at construction, not deep inside a probe. | |
| 71 | + """ | |
| 72 | + from dlm_sway.core.scoring import DifferentialBackend as DiffBackend | |
| 73 | + | |
| 74 | + entry = base_spec.entry_point | |
| 75 | + if not entry: | |
| 76 | + raise SpecValidationError( | |
| 77 | + "kind='custom' requires an entry_point of the form 'pkg.module:ClassName'" | |
| 78 | + ) | |
| 79 | + if ":" not in entry: | |
| 80 | + raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}") | |
| 81 | + module_path, _, class_name = entry.partition(":") | |
| 82 | + if not module_path or not class_name: | |
| 83 | + raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}") | |
| 84 | + | |
| 85 | + import importlib | |
| 86 | + | |
| 87 | + try: | |
| 88 | + module = importlib.import_module(module_path) | |
| 89 | + except ImportError as exc: | |
| 90 | + raise SpecValidationError( | |
| 91 | + f"custom backend: cannot import module {module_path!r}: {exc}" | |
| 92 | + ) from exc | |
| 93 | + cls = getattr(module, class_name, None) | |
| 94 | + if cls is None: | |
| 95 | + raise SpecValidationError( | |
| 96 | + f"custom backend: module {module_path!r} has no attribute {class_name!r}" | |
| 97 | + ) | |
| 98 | + | |
| 99 | + try: | |
| 100 | + instance = cls(base_spec=base_spec, adapter_path=adapter) | |
| 101 | + except TypeError as exc: | |
| 102 | + raise SpecValidationError( | |
| 103 | + f"custom backend {entry!r} constructor signature mismatch: {exc}. " | |
| 104 | + "Expected Cls(base_spec: ModelSpec, adapter_path: Path | None)" | |
| 105 | + ) from exc | |
| 106 | + | |
| 107 | + if not isinstance(instance, DiffBackend): | |
| 108 | + raise SpecValidationError( | |
| 109 | + f"custom backend {entry!r} does not satisfy DifferentialBackend " | |
| 110 | + "(needs as_base() and as_finetuned() context managers)" | |
| 111 | + ) | |
| 112 | + return instance | |
| 113 | + | |
| 114 | + | |
| 67 | 115 | __all__ = ["build"] |
tests/unit/test_backend_registry.pymodified@@ -31,13 +31,92 @@ class TestRegistry: | ||
| 31 | 31 | build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a"))) |
| 32 | 32 | assert exc_info.value.backend == "mlx" |
| 33 | 33 | |
| 34 | - def test_custom_not_yet_available(self) -> None: | |
| 35 | - with pytest.raises(BackendNotAvailableError): | |
| 34 | + def test_custom_requires_entry_point(self) -> None: | |
| 35 | + with pytest.raises(SpecValidationError, match="entry_point"): | |
| 36 | + build(ModelSpec(base="x", kind="custom", adapter=Path("/tmp/a"))) | |
| 37 | + | |
| 38 | + def test_custom_validates_entry_point_shape(self) -> None: | |
| 39 | + with pytest.raises(SpecValidationError, match="pkg.module:ClassName"): | |
| 40 | + build( | |
| 41 | + ModelSpec( | |
| 42 | + base="x", | |
| 43 | + kind="custom", | |
| 44 | + entry_point="not_a_valid_entry_point", | |
| 45 | + adapter=Path("/tmp/a"), | |
| 46 | + ) | |
| 47 | + ) | |
| 48 | + | |
| 49 | + def test_custom_rejects_unimportable_module(self) -> None: | |
| 50 | + with pytest.raises(SpecValidationError, match="cannot import"): | |
| 36 | 51 | build( |
| 37 | 52 | ModelSpec( |
| 38 | 53 | base="x", |
| 39 | 54 | kind="custom", |
| 40 | - entry_point="pkg:Backend", | |
| 55 | + entry_point="nonexistent_pkg_xyz:Backend", | |
| 41 | 56 | adapter=Path("/tmp/a"), |
| 42 | 57 | ) |
| 43 | 58 | ) |
| 59 | + | |
| 60 | + def test_custom_rejects_missing_class(self) -> None: | |
| 61 | + with pytest.raises(SpecValidationError, match="has no attribute"): | |
| 62 | + build( | |
| 63 | + ModelSpec( | |
| 64 | + base="x", | |
| 65 | + kind="custom", | |
| 66 | + entry_point="dlm_sway:NoSuchClass", | |
| 67 | + adapter=Path("/tmp/a"), | |
| 68 | + ) | |
| 69 | + ) | |
| 70 | + | |
| 71 | + def test_custom_rejects_non_differential_class(self) -> None: | |
| 72 | + # A class that accepts the canonical constructor args but doesn't | |
| 73 | + # implement the protocol. | |
| 74 | + import sys | |
| 75 | + import types | |
| 76 | + | |
| 77 | + class _Bad: | |
| 78 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] | |
| 79 | + del base_spec, adapter_path | |
| 80 | + | |
| 81 | + mod = types.ModuleType("_sway_bad_mod") | |
| 82 | + mod.Bad = _Bad # type: ignore[attr-defined] | |
| 83 | + sys.modules["_sway_bad_mod"] = mod | |
| 84 | + | |
| 85 | + with pytest.raises(SpecValidationError, match="DifferentialBackend"): | |
| 86 | + build( | |
| 87 | + ModelSpec( | |
| 88 | + base="x", | |
| 89 | + kind="custom", | |
| 90 | + entry_point="_sway_bad_mod:Bad", | |
| 91 | + adapter=Path("/tmp/a"), | |
| 92 | + ) | |
| 93 | + ) | |
| 94 | + | |
| 95 | + def test_custom_dispatches_to_valid_backend(self) -> None: | |
| 96 | + # Use the dummy backend via a custom entry point. The dummy class's | |
| 97 | + # __init__ takes different args, so we write a thin adapter class. | |
| 98 | + from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses | |
| 99 | + | |
| 100 | + class _AdapterBackend(DummyDifferentialBackend): | |
| 101 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] | |
| 102 | + super().__init__(base=DummyResponses(), ft=DummyResponses()) | |
| 103 | + | |
| 104 | + # Register on a throwaway module we can find by name. | |
| 105 | + import sys | |
| 106 | + import types | |
| 107 | + | |
| 108 | + mod = types.ModuleType("_sway_custom_test_mod") | |
| 109 | + mod.AdapterBackend = _AdapterBackend # type: ignore[attr-defined] | |
| 110 | + sys.modules["_sway_custom_test_mod"] = mod | |
| 111 | + | |
| 112 | + backend = build( | |
| 113 | + ModelSpec( | |
| 114 | + base="x", | |
| 115 | + kind="custom", | |
| 116 | + entry_point="_sway_custom_test_mod:AdapterBackend", | |
| 117 | + adapter=Path("/tmp/a"), | |
| 118 | + ) | |
| 119 | + ) | |
| 120 | + from dlm_sway.core.scoring import DifferentialBackend | |
| 121 | + | |
| 122 | + assert isinstance(backend, DifferentialBackend) | |