sway(backends): custom entry-point dispatch with protocol validation
- SHA
ac3f2ff40004e551199966391b1876b9e5f67e20- Parents
-
c666cbf - Tree
d64360e
ac3f2ff
ac3f2ff40004e551199966391b1876b9e5f67e20c666cbf
d64360e| Status | File | + | - |
|---|---|---|---|
| M |
src/dlm_sway/backends/__init__.py
|
53 | 5 |
| M |
tests/unit/test_backend_registry.py
|
82 | 3 |
src/dlm_sway/backends/__init__.pymodified@@ -55,13 +55,61 @@ def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> Differen | |||
| 55 | ) | 55 | ) |
| 56 | 56 | ||
| 57 | if base_spec.kind == "custom": | 57 | if base_spec.kind == "custom": |
| 58 | - raise BackendNotAvailableError( | 58 | + return _load_custom(base_spec, effective_adapter) |
| 59 | - "custom", | ||
| 60 | - extra="hf", | ||
| 61 | - hint="Custom backend entry-point dispatch shipping in a later milestone.", | ||
| 62 | - ) | ||
| 63 | 59 | ||
| 64 | raise SpecValidationError(f"unknown backend kind: {base_spec.kind!r}") | 60 | raise SpecValidationError(f"unknown backend kind: {base_spec.kind!r}") |
| 65 | 61 | ||
| 66 | 62 | ||
| 63 | +def _load_custom(base_spec: ModelSpec, adapter: Path | None) -> DifferentialBackend: | ||
| 64 | + """Dispatch to a user-supplied backend via ``entry_point='pkg.mod:Name'``. | ||
| 65 | + | ||
| 66 | + The imported class is instantiated as ``Cls(base_spec=..., adapter_path=...)`` | ||
| 67 | + — the same signature as :class:`dlm_sway.backends.hf.HuggingFaceDifferentialBackend` | ||
| 68 | + so authors can model their implementation on the built-in. The | ||
| 69 | + result is runtime-checked against :class:`DifferentialBackend` so | ||
| 70 | + protocol violations fail at construction, not deep inside a probe. | ||
| 71 | + """ | ||
| 72 | + from dlm_sway.core.scoring import DifferentialBackend as DiffBackend | ||
| 73 | + | ||
| 74 | + entry = base_spec.entry_point | ||
| 75 | + if not entry: | ||
| 76 | + raise SpecValidationError( | ||
| 77 | + "kind='custom' requires an entry_point of the form 'pkg.module:ClassName'" | ||
| 78 | + ) | ||
| 79 | + if ":" not in entry: | ||
| 80 | + raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}") | ||
| 81 | + module_path, _, class_name = entry.partition(":") | ||
| 82 | + if not module_path or not class_name: | ||
| 83 | + raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}") | ||
| 84 | + | ||
| 85 | + import importlib | ||
| 86 | + | ||
| 87 | + try: | ||
| 88 | + module = importlib.import_module(module_path) | ||
| 89 | + except ImportError as exc: | ||
| 90 | + raise SpecValidationError( | ||
| 91 | + f"custom backend: cannot import module {module_path!r}: {exc}" | ||
| 92 | + ) from exc | ||
| 93 | + cls = getattr(module, class_name, None) | ||
| 94 | + if cls is None: | ||
| 95 | + raise SpecValidationError( | ||
| 96 | + f"custom backend: module {module_path!r} has no attribute {class_name!r}" | ||
| 97 | + ) | ||
| 98 | + | ||
| 99 | + try: | ||
| 100 | + instance = cls(base_spec=base_spec, adapter_path=adapter) | ||
| 101 | + except TypeError as exc: | ||
| 102 | + raise SpecValidationError( | ||
| 103 | + f"custom backend {entry!r} constructor signature mismatch: {exc}. " | ||
| 104 | + "Expected Cls(base_spec: ModelSpec, adapter_path: Path | None)" | ||
| 105 | + ) from exc | ||
| 106 | + | ||
| 107 | + if not isinstance(instance, DiffBackend): | ||
| 108 | + raise SpecValidationError( | ||
| 109 | + f"custom backend {entry!r} does not satisfy DifferentialBackend " | ||
| 110 | + "(needs as_base() and as_finetuned() context managers)" | ||
| 111 | + ) | ||
| 112 | + return instance | ||
| 113 | + | ||
| 114 | + | ||
| 67 | __all__ = ["build"] | 115 | __all__ = ["build"] |
tests/unit/test_backend_registry.pymodified@@ -31,13 +31,92 @@ class TestRegistry: | |||
| 31 | build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a"))) | 31 | build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a"))) |
| 32 | assert exc_info.value.backend == "mlx" | 32 | assert exc_info.value.backend == "mlx" |
| 33 | 33 | ||
| 34 | - def test_custom_not_yet_available(self) -> None: | 34 | + def test_custom_requires_entry_point(self) -> None: |
| 35 | - with pytest.raises(BackendNotAvailableError): | 35 | + with pytest.raises(SpecValidationError, match="entry_point"): |
| 36 | + build(ModelSpec(base="x", kind="custom", adapter=Path("/tmp/a"))) | ||
| 37 | + | ||
| 38 | + def test_custom_validates_entry_point_shape(self) -> None: | ||
| 39 | + with pytest.raises(SpecValidationError, match="pkg.module:ClassName"): | ||
| 40 | + build( | ||
| 41 | + ModelSpec( | ||
| 42 | + base="x", | ||
| 43 | + kind="custom", | ||
| 44 | + entry_point="not_a_valid_entry_point", | ||
| 45 | + adapter=Path("/tmp/a"), | ||
| 46 | + ) | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + def test_custom_rejects_unimportable_module(self) -> None: | ||
| 50 | + with pytest.raises(SpecValidationError, match="cannot import"): | ||
| 36 | build( | 51 | build( |
| 37 | ModelSpec( | 52 | ModelSpec( |
| 38 | base="x", | 53 | base="x", |
| 39 | kind="custom", | 54 | kind="custom", |
| 40 | - entry_point="pkg:Backend", | 55 | + entry_point="nonexistent_pkg_xyz:Backend", |
| 41 | adapter=Path("/tmp/a"), | 56 | adapter=Path("/tmp/a"), |
| 42 | ) | 57 | ) |
| 43 | ) | 58 | ) |
| 59 | + | ||
| 60 | + def test_custom_rejects_missing_class(self) -> None: | ||
| 61 | + with pytest.raises(SpecValidationError, match="has no attribute"): | ||
| 62 | + build( | ||
| 63 | + ModelSpec( | ||
| 64 | + base="x", | ||
| 65 | + kind="custom", | ||
| 66 | + entry_point="dlm_sway:NoSuchClass", | ||
| 67 | + adapter=Path("/tmp/a"), | ||
| 68 | + ) | ||
| 69 | + ) | ||
| 70 | + | ||
| 71 | + def test_custom_rejects_non_differential_class(self) -> None: | ||
| 72 | + # A class that accepts the canonical constructor args but doesn't | ||
| 73 | + # implement the protocol. | ||
| 74 | + import sys | ||
| 75 | + import types | ||
| 76 | + | ||
| 77 | + class _Bad: | ||
| 78 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] | ||
| 79 | + del base_spec, adapter_path | ||
| 80 | + | ||
| 81 | + mod = types.ModuleType("_sway_bad_mod") | ||
| 82 | + mod.Bad = _Bad # type: ignore[attr-defined] | ||
| 83 | + sys.modules["_sway_bad_mod"] = mod | ||
| 84 | + | ||
| 85 | + with pytest.raises(SpecValidationError, match="DifferentialBackend"): | ||
| 86 | + build( | ||
| 87 | + ModelSpec( | ||
| 88 | + base="x", | ||
| 89 | + kind="custom", | ||
| 90 | + entry_point="_sway_bad_mod:Bad", | ||
| 91 | + adapter=Path("/tmp/a"), | ||
| 92 | + ) | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + def test_custom_dispatches_to_valid_backend(self) -> None: | ||
| 96 | + # Use the dummy backend via a custom entry point. The dummy class's | ||
| 97 | + # __init__ takes different args, so we write a thin adapter class. | ||
| 98 | + from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses | ||
| 99 | + | ||
| 100 | + class _AdapterBackend(DummyDifferentialBackend): | ||
| 101 | + def __init__(self, base_spec, adapter_path): # type: ignore[no-untyped-def] | ||
| 102 | + super().__init__(base=DummyResponses(), ft=DummyResponses()) | ||
| 103 | + | ||
| 104 | + # Register on a throwaway module we can find by name. | ||
| 105 | + import sys | ||
| 106 | + import types | ||
| 107 | + | ||
| 108 | + mod = types.ModuleType("_sway_custom_test_mod") | ||
| 109 | + mod.AdapterBackend = _AdapterBackend # type: ignore[attr-defined] | ||
| 110 | + sys.modules["_sway_custom_test_mod"] = mod | ||
| 111 | + | ||
| 112 | + backend = build( | ||
| 113 | + ModelSpec( | ||
| 114 | + base="x", | ||
| 115 | + kind="custom", | ||
| 116 | + entry_point="_sway_custom_test_mod:AdapterBackend", | ||
| 117 | + adapter=Path("/tmp/a"), | ||
| 118 | + ) | ||
| 119 | + ) | ||
| 120 | + from dlm_sway.core.scoring import DifferentialBackend | ||
| 121 | + | ||
| 122 | + assert isinstance(backend, DifferentialBackend) | ||