tenseleyflow/sway / ac3f2ff

Browse files

sway(backends): custom entry-point dispatch with protocol validation

Authored by espadonne
SHA
ac3f2ff40004e551199966391b1876b9e5f67e20
Parents
c666cbf
Tree
d64360e

2 changed files

StatusFile+-
M src/dlm_sway/backends/__init__.py 53 5
M tests/unit/test_backend_registry.py 82 3
src/dlm_sway/backends/__init__.pymodified
@@ -55,13 +55,61 @@ def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> Differen
5555
         )
5656
 
5757
     if base_spec.kind == "custom":
58
-        raise BackendNotAvailableError(
59
-            "custom",
60
-            extra="hf",
61
-            hint="Custom backend entry-point dispatch shipping in a later milestone.",
62
-        )
58
+        return _load_custom(base_spec, effective_adapter)
6359
 
6460
     raise SpecValidationError(f"unknown backend kind: {base_spec.kind!r}")
6561
 
6662
 
63
+def _load_custom(base_spec: ModelSpec, adapter: Path | None) -> DifferentialBackend:
64
+    """Dispatch to a user-supplied backend via ``entry_point='pkg.mod:Name'``.
65
+
66
+    The imported class is instantiated as ``Cls(base_spec=..., adapter_path=...)``
67
+    — the same signature as :class:`dlm_sway.backends.hf.HuggingFaceDifferentialBackend`
68
+    so authors can model their implementation on the built-in. The
69
+    result is runtime-checked against :class:`DifferentialBackend` so
70
+    protocol violations fail at construction, not deep inside a probe.
71
+    """
72
+    from dlm_sway.core.scoring import DifferentialBackend as DiffBackend
73
+
74
+    entry = base_spec.entry_point
75
+    if not entry:
76
+        raise SpecValidationError(
77
+            "kind='custom' requires an entry_point of the form 'pkg.module:ClassName'"
78
+        )
79
+    if ":" not in entry:
80
+        raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}")
81
+    module_path, _, class_name = entry.partition(":")
82
+    if not module_path or not class_name:
83
+        raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}")
84
+
85
+    import importlib
86
+
87
+    try:
88
+        module = importlib.import_module(module_path)
89
+    except ImportError as exc:
90
+        raise SpecValidationError(
91
+            f"custom backend: cannot import module {module_path!r}: {exc}"
92
+        ) from exc
93
+    cls = getattr(module, class_name, None)
94
+    if cls is None:
95
+        raise SpecValidationError(
96
+            f"custom backend: module {module_path!r} has no attribute {class_name!r}"
97
+        )
98
+
99
+    try:
100
+        instance = cls(base_spec=base_spec, adapter_path=adapter)
101
+    except TypeError as exc:
102
+        raise SpecValidationError(
103
+            f"custom backend {entry!r} constructor signature mismatch: {exc}. "
104
+            "Expected Cls(base_spec: ModelSpec, adapter_path: Path | None)"
105
+        ) from exc
106
+
107
+    if not isinstance(instance, DiffBackend):
108
+        raise SpecValidationError(
109
+            f"custom backend {entry!r} does not satisfy DifferentialBackend "
110
+            "(needs as_base() and as_finetuned() context managers)"
111
+        )
112
+    return instance
113
+
114
+
67115
 __all__ = ["build"]
tests/unit/test_backend_registry.pymodified
@@ -31,13 +31,92 @@ class TestRegistry:
3131
             build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a")))
3232
         assert exc_info.value.backend == "mlx"
3333
 
34
-    def test_custom_not_yet_available(self) -> None:
35
-        with pytest.raises(BackendNotAvailableError):
34
+    def test_custom_requires_entry_point(self) -> None:
35
+        with pytest.raises(SpecValidationError, match="entry_point"):
36
+            build(ModelSpec(base="x", kind="custom", adapter=Path("/tmp/a")))
37
+
38
+    def test_custom_validates_entry_point_shape(self) -> None:
39
+        with pytest.raises(SpecValidationError, match="pkg.module:ClassName"):
40
+            build(
41
+                ModelSpec(
42
+                    base="x",
43
+                    kind="custom",
44
+                    entry_point="not_a_valid_entry_point",
45
+                    adapter=Path("/tmp/a"),
46
+                )
47
+            )
48
+
49
+    def test_custom_rejects_unimportable_module(self) -> None:
50
+        with pytest.raises(SpecValidationError, match="cannot import"):
3651
             build(
3752
                 ModelSpec(
3853
                     base="x",
3954
                     kind="custom",
40
-                    entry_point="pkg:Backend",
55
+                    entry_point="nonexistent_pkg_xyz:Backend",
4156
                     adapter=Path("/tmp/a"),
4257
                 )
4358
             )
59
+
60
+    def test_custom_rejects_missing_class(self) -> None:
61
+        with pytest.raises(SpecValidationError, match="has no attribute"):
62
+            build(
63
+                ModelSpec(
64
+                    base="x",
65
+                    kind="custom",
66
+                    entry_point="dlm_sway:NoSuchClass",
67
+                    adapter=Path("/tmp/a"),
68
+                )
69
+            )
70
+
71
+    def test_custom_rejects_non_differential_class(self) -> None:
72
+        # A class that accepts the canonical constructor args but doesn't
73
+        # implement the protocol.
74
+        import sys
75
+        import types
76
+
77
+        class _Bad:
78
+            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
79
+                del base_spec, adapter_path
80
+
81
+        mod = types.ModuleType("_sway_bad_mod")
82
+        mod.Bad = _Bad  # type: ignore[attr-defined]
83
+        sys.modules["_sway_bad_mod"] = mod
84
+
85
+        with pytest.raises(SpecValidationError, match="DifferentialBackend"):
86
+            build(
87
+                ModelSpec(
88
+                    base="x",
89
+                    kind="custom",
90
+                    entry_point="_sway_bad_mod:Bad",
91
+                    adapter=Path("/tmp/a"),
92
+                )
93
+            )
94
+
95
+    def test_custom_dispatches_to_valid_backend(self) -> None:
96
+        # Use the dummy backend via a custom entry point. The dummy class's
97
+        # __init__ takes different args, so we write a thin adapter class.
98
+        from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
99
+
100
+        class _AdapterBackend(DummyDifferentialBackend):
101
+            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
102
+                super().__init__(base=DummyResponses(), ft=DummyResponses())
103
+
104
+        # Register on a throwaway module we can find by name.
105
+        import sys
106
+        import types
107
+
108
+        mod = types.ModuleType("_sway_custom_test_mod")
109
+        mod.AdapterBackend = _AdapterBackend  # type: ignore[attr-defined]
110
+        sys.modules["_sway_custom_test_mod"] = mod
111
+
112
+        backend = build(
113
+            ModelSpec(
114
+                base="x",
115
+                kind="custom",
116
+                entry_point="_sway_custom_test_mod:AdapterBackend",
117
+                adapter=Path("/tmp/a"),
118
+            )
119
+        )
120
+        from dlm_sway.core.scoring import DifferentialBackend
121
+
122
+        assert isinstance(backend, DifferentialBackend)