Python · 6714 bytes Raw Blame History
1 """Scoring backends: HuggingFace (``hf``), MLX (``mlx``), API (``api``), dummy, custom.
2
3 Backends are constructed from a :class:`~dlm_sway.core.model.ModelSpec`
4 via :func:`build`. Heavy backends (HF, MLX) import their framework only
5 on construction so ``import dlm_sway`` stays cheap for users who only
6 touch the dummy backend or the spec loader.
7 """
8
9 from __future__ import annotations
10
11 from pathlib import Path
12 from typing import TYPE_CHECKING
13
14 from dlm_sway.core.errors import SpecValidationError
15 from dlm_sway.core.model import ModelSpec
16
17 if TYPE_CHECKING:
18 from dlm_sway.core.scoring import DifferentialBackend
19
20
def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> DifferentialBackend:
    """Turn a model spec into a ready-to-use backend by dispatching on ``base_spec.kind``.

    ``adapter_path`` overrides the spec's own ``adapter`` when given;
    when it is ``None`` the spec's ``adapter`` is used instead. Keeping
    it as a keyword lets the same entry point serve the "differential"
    case (base + adapter on one loaded model) as well as future
    split-load paths.

    Framework-specific backend modules are imported inside their own
    branch, so selecting one kind never imports another kind's module.

    Raises :class:`SpecValidationError` when the kind is ``dummy`` or
    unknown, when ``hf`` / ``mlx`` have no adapter to load, or when
    ``api`` has no ``endpoint``.
    """
    # An explicit keyword argument wins over whatever the spec carries.
    if adapter_path is None:
        adapter = base_spec.adapter
    else:
        adapter = adapter_path
    kind = base_spec.kind

    if kind == "dummy":
        # The dummy backend exists for tests that pre-populate
        # responses; it has nothing to do with a spec, so refuse
        # loudly rather than guess at what the caller wanted.
        raise SpecValidationError(
            "kind='dummy' backends must be constructed directly via "
            "DummyDifferentialBackend(base=..., ft=...); they cannot be "
            "materialized from a ModelSpec."
        )

    if kind == "hf":
        if adapter is not None:
            from dlm_sway.backends.hf import HuggingFaceDifferentialBackend

            return HuggingFaceDifferentialBackend(base_spec=base_spec, adapter_path=adapter)
        raise SpecValidationError(
            "hf backend requires an adapter path (set `adapter:` on the ft model)"
        )

    if kind == "mlx":
        if adapter is not None:
            from dlm_sway.backends.mlx import MLXDifferentialBackend

            return MLXDifferentialBackend(base_spec=base_spec, adapter_path=adapter)
        raise SpecValidationError(
            "mlx backend requires an adapter path (set `adapter:` on the ft model; "
            "must be an MLX .npz adapter — use dlm's peft→mlx converter if needed)"
        )

    if kind == "api":
        # One API backend == one endpoint; there is no local model to
        # toggle between base and fine-tuned. The expected route is
        # ``defaults.differential: false`` in sway.yaml, which goes
        # through :func:`build_two_separate`: that calls this dispatch
        # once per side and wraps the pair in
        # :class:`TwoModelDifferential`. The resolved adapter is
        # deliberately unused here — for an API backend the "adapter"
        # is a *different model name*, not a local file path.
        endpoint = base_spec.endpoint
        if endpoint is None:
            raise SpecValidationError(
                "api backend requires `endpoint:` on the ModelSpec "
                "(the base URL of the /v1/completions server)"
            )
        from dlm_sway.backends.api import ApiScoringBackend

        return ApiScoringBackend(base_url=endpoint, model_name=base_spec.base)

    if kind == "custom":
        return _load_custom(base_spec, adapter)

    raise SpecValidationError(f"unknown backend kind: {kind!r}")
85
86
def _load_custom(base_spec: ModelSpec, adapter: Path | None) -> DifferentialBackend:
    """Dispatch to a user-supplied backend via ``entry_point='pkg.mod:Name'``.

    The imported class is instantiated as ``Cls(base_spec=..., adapter_path=...)``
    — the same signature as :class:`dlm_sway.backends.hf.HuggingFaceDifferentialBackend`
    so authors can model their implementation on the built-in. The
    result is runtime-checked against :class:`DifferentialBackend` and
    the optional :class:`NullCalibratedBackend` /
    :class:`ScalableDifferentialBackend` so protocol violations fail
    at construction, not deep inside a probe (B20). The set of
    satisfied protocols is recorded on the instance as
    ``__sway_protocols__: tuple[str, ...]`` for the report's
    backend-info section.

    ``base_spec.entry_point`` must be an absolute ``'pkg.module:ClassName'``
    string; ``adapter`` is forwarded unchanged as ``adapter_path`` and
    may be ``None``.

    Raises :class:`SpecValidationError` when the entry point is missing,
    malformed or relative; when the module cannot be imported; when the
    named attribute is absent or not callable; when the constructor
    rejects the expected keywords; when the instance does not satisfy
    :class:`DifferentialBackend`; or when the instance refuses the
    ``__sway_protocols__`` attribute.
    """
    from dlm_sway.core.scoring import (
        DifferentialBackend as DiffBackend,
    )
    from dlm_sway.core.scoring import (
        NullCalibratedBackend,
        ScalableDifferentialBackend,
    )

    entry = base_spec.entry_point
    if not entry:
        raise SpecValidationError(
            "kind='custom' requires an entry_point of the form 'pkg.module:ClassName'"
        )
    # ``partition`` yields an empty separator when no ':' is present, so a
    # single check covers "no colon", "empty module" and "empty class".
    module_path, sep, class_name = entry.partition(":")
    if not (sep and module_path and class_name):
        raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}")
    if module_path.startswith("."):
        # ``importlib.import_module`` raises a bare TypeError for a relative
        # name when no anchor package is given; reject it up front so the
        # user gets a spec error instead of an unrelated-looking traceback.
        raise SpecValidationError(
            f"entry_point must use an absolute module path, got {entry!r}"
        )

    import importlib

    try:
        module = importlib.import_module(module_path)
    except ImportError as exc:
        raise SpecValidationError(
            f"custom backend: cannot import module {module_path!r}: {exc}"
        ) from exc
    cls = getattr(module, class_name, None)
    if cls is None:
        raise SpecValidationError(
            f"custom backend: module {module_path!r} has no attribute {class_name!r}"
        )
    if not callable(cls):
        # Without this guard, calling a non-callable raises TypeError inside
        # the ``try`` below and is misreported as a signature mismatch.
        raise SpecValidationError(
            f"custom backend: {entry!r} resolves to a non-callable "
            f"{type(cls).__name__!r} object, expected a backend class"
        )

    try:
        instance = cls(base_spec=base_spec, adapter_path=adapter)
    except TypeError as exc:
        # NOTE: a TypeError raised from *inside* the constructor body lands
        # here too; the original message is embedded so it stays diagnosable.
        raise SpecValidationError(
            f"custom backend {entry!r} constructor signature mismatch: {exc}. "
            "Expected Cls(base_spec: ModelSpec, adapter_path: Path | None)"
        ) from exc

    if not isinstance(instance, DiffBackend):
        raise SpecValidationError(
            f"custom backend {entry!r} does not satisfy DifferentialBackend "
            "(needs as_base() and as_finetuned() context managers)"
        )

    # B20: probe optional protocols and record them so the runner /
    # report can show which downstream features are available without
    # repeated isinstance checks.
    satisfied: list[str] = ["DifferentialBackend"]
    if isinstance(instance, NullCalibratedBackend):
        satisfied.append("NullCalibratedBackend")
    if isinstance(instance, ScalableDifferentialBackend):
        satisfied.append("ScalableDifferentialBackend")
    try:
        instance.__sway_protocols__ = tuple(satisfied)  # type: ignore[attr-defined]
    except AttributeError as exc:
        # Instances without a ``__dict__`` (``__slots__``) or with a frozen
        # ``__setattr__`` reject new attributes; report that as a spec
        # problem rather than leaking a raw AttributeError.
        raise SpecValidationError(
            f"custom backend {entry!r} does not allow setting `__sway_protocols__` "
            "on its instances (is the class frozen or using __slots__ without "
            f"a `__sway_protocols__` slot?): {exc}"
        ) from exc
    return instance
158
159
160 from dlm_sway.backends.two_model import TwoModelDifferential, build_two_separate # noqa: E402
161
162 __all__ = ["TwoModelDifferential", "build", "build_two_separate"]