"""Scoring backends: HuggingFace (``hf``), MLX (``mlx``), HTTP API (``api``), dummy, custom.
| 2 | |
| 3 | Backends are constructed from a :class:`~dlm_sway.core.model.ModelSpec` |
| 4 | via :func:`build`. Heavy backends (HF, MLX) import their framework only |
| 5 | on construction so ``import dlm_sway`` stays cheap for users who only |
| 6 | touch the dummy backend or the spec loader. |
| 7 | """ |
| 8 | |
| 9 | from __future__ import annotations |
| 10 | |
| 11 | from pathlib import Path |
| 12 | from typing import TYPE_CHECKING |
| 13 | |
| 14 | from dlm_sway.core.errors import SpecValidationError |
| 15 | from dlm_sway.core.model import ModelSpec |
| 16 | |
| 17 | if TYPE_CHECKING: |
| 18 | from dlm_sway.core.scoring import DifferentialBackend |
| 19 | |
| 20 | |
def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> DifferentialBackend:
    """Materialize a scoring backend for ``base_spec`` by dispatching on ``kind``.

    ``adapter_path``, when given, takes precedence over ``base_spec.adapter``
    (which usually mirrors ``ft.adapter`` in the spec). Exposing it as a
    keyword lets one function serve the "differential" path (base and
    adapter toggled on a single loaded model) as well as split-load
    callers such as :func:`build_two_separate`.

    Per kind:

    * ``"dummy"`` — always rejected; build ``DummyDifferentialBackend``
      directly instead.
    * ``"hf"`` / ``"mlx"`` — need an adapter path; the framework-specific
      backend module is imported only on that branch.
    * ``"api"`` — needs ``base_spec.endpoint``; the adapter is not used,
      and an ``ApiScoringBackend`` for model ``base_spec.base`` is
      returned (NOTE(review): a single-endpoint backend rather than a
      base/finetuned pair, despite the return annotation).
    * ``"custom"`` — handed to :func:`_load_custom`.

    Raises:
        SpecValidationError: ``kind`` is ``"dummy"`` or unrecognised, or a
            required adapter path / endpoint is missing.
    """
    # An explicit keyword wins; otherwise fall back to the spec's adapter.
    adapter = base_spec.adapter if adapter_path is None else adapter_path
    kind = base_spec.kind

    if kind == "dummy":
        # Dummy backends are test fixtures with pre-populated responses;
        # reaching one through a spec is always a mistake, so fail loudly.
        raise SpecValidationError(
            "kind='dummy' backends must be constructed directly via "
            "DummyDifferentialBackend(base=..., ft=...); they cannot be "
            "materialized from a ModelSpec."
        )

    if kind == "hf":
        if adapter is not None:
            # Deferred import keeps the HF stack out of `import dlm_sway`.
            from dlm_sway.backends.hf import HuggingFaceDifferentialBackend

            return HuggingFaceDifferentialBackend(base_spec=base_spec, adapter_path=adapter)
        raise SpecValidationError(
            "hf backend requires an adapter path (set `adapter:` on the ft model)"
        )

    if kind == "mlx":
        if adapter is not None:
            # Deferred import, same reason as the HF branch.
            from dlm_sway.backends.mlx import MLXDifferentialBackend

            return MLXDifferentialBackend(base_spec=base_spec, adapter_path=adapter)
        raise SpecValidationError(
            "mlx backend requires an adapter path (set `adapter:` on the ft model; "
            "must be an MLX .npz adapter — use dlm's peft→mlx converter if needed)"
        )

    if kind == "api":
        # One remote endpoint with no local on/off toggle. Two-sided
        # comparisons go through build_two_separate
        # (``defaults.differential: false``), which calls this function once
        # per side. The adapter is unused here: for an API backend the
        # "finetuned" side is just a different model name.
        endpoint = base_spec.endpoint
        if endpoint is not None:
            from dlm_sway.backends.api import ApiScoringBackend

            return ApiScoringBackend(base_url=endpoint, model_name=base_spec.base)
        raise SpecValidationError(
            "api backend requires `endpoint:` on the ModelSpec "
            "(the base URL of the /v1/completions server)"
        )

    if kind == "custom":
        return _load_custom(base_spec, adapter)

    raise SpecValidationError(f"unknown backend kind: {kind!r}")
| 85 | |
| 86 | |
def _load_custom(base_spec: ModelSpec, adapter: Path | None) -> DifferentialBackend:
    """Dispatch to a user-supplied backend via ``entry_point='pkg.mod:Name'``.

    The imported class is instantiated as ``Cls(base_spec=..., adapter_path=...)``
    — the same signature as :class:`dlm_sway.backends.hf.HuggingFaceDifferentialBackend`
    so authors can model their implementation on the built-in. Before the
    call, the keywords are bound against :func:`inspect.signature` of the
    class. A genuine signature mismatch is therefore reported as such, and
    a ``TypeError`` raised *inside* a correctly-shaped constructor is not
    mislabelled. The result is runtime-checked against
    :class:`DifferentialBackend` and the optional
    :class:`NullCalibratedBackend` / :class:`ScalableDifferentialBackend`
    protocols, so protocol violations fail at construction rather than
    deep inside a probe (B20). The set of satisfied protocols is recorded
    on the instance as ``__sway_protocols__: tuple[str, ...]`` for the
    report's backend-info section.

    Args:
        base_spec: The model spec; its ``entry_point`` names the class.
        adapter: Forwarded to the constructor as ``adapter_path``.

    Returns:
        The constructed backend, with ``__sway_protocols__`` set on it.

    Raises:
        SpecValidationError: ``entry_point`` is missing or malformed, the
            module cannot be imported, the attribute is missing or not
            callable, the constructor signature does not accept
            ``(base_spec=..., adapter_path=...)``, the constructor raises
            ``TypeError``, or the instance does not satisfy
            ``DifferentialBackend``.
    """
    import importlib
    import inspect

    from dlm_sway.core.scoring import (
        DifferentialBackend as DiffBackend,
    )
    from dlm_sway.core.scoring import (
        NullCalibratedBackend,
        ScalableDifferentialBackend,
    )

    entry = base_spec.entry_point
    if not entry:
        raise SpecValidationError(
            "kind='custom' requires an entry_point of the form 'pkg.module:ClassName'"
        )
    if ":" not in entry:
        raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}")
    module_path, _, class_name = entry.partition(":")
    if not module_path or not class_name:
        raise SpecValidationError(f"entry_point must be 'pkg.module:ClassName', got {entry!r}")

    try:
        module = importlib.import_module(module_path)
    except ImportError as exc:
        raise SpecValidationError(
            f"custom backend: cannot import module {module_path!r}: {exc}"
        ) from exc
    cls = getattr(module, class_name, None)
    if cls is None:
        raise SpecValidationError(
            f"custom backend: module {module_path!r} has no attribute {class_name!r}"
        )
    if not callable(cls):
        # Previously surfaced as a misleading "signature mismatch" via the
        # TypeError that calling a non-callable raises.
        raise SpecValidationError(
            f"custom backend {entry!r}: attribute {class_name!r} is not callable "
            f"(got {type(cls).__name__})"
        )

    def _signature_mismatch(exc: TypeError) -> SpecValidationError:
        # Single source for the mismatch message used by both checks below.
        return SpecValidationError(
            f"custom backend {entry!r} constructor signature mismatch: {exc}. "
            "Expected Cls(base_spec: ModelSpec, adapter_path: Path | None)"
        )

    try:
        signature = inspect.signature(cls)
    except (TypeError, ValueError):
        # Some callables (e.g. C-implemented types) expose no signature;
        # fall back to diagnosing from the call itself.
        signature = None
    if signature is not None:
        try:
            signature.bind(base_spec=base_spec, adapter_path=adapter)
        except TypeError as exc:
            raise _signature_mismatch(exc) from exc

    try:
        instance = cls(base_spec=base_spec, adapter_path=adapter)
    except TypeError as exc:
        if signature is None:
            # No up-front check was possible, so a TypeError here is most
            # plausibly the call shape itself.
            raise _signature_mismatch(exc) from exc
        # The signature accepted our keywords, so this came from inside the
        # constructor; say so instead of blaming the signature.
        raise SpecValidationError(
            f"custom backend {entry!r} constructor raised TypeError: {exc}"
        ) from exc

    if not isinstance(instance, DiffBackend):
        raise SpecValidationError(
            f"custom backend {entry!r} does not satisfy DifferentialBackend "
            "(needs as_base() and as_finetuned() context managers)"
        )

    # B20: probe optional protocols and record them so the runner /
    # report can show which downstream features are available without
    # repeated isinstance checks.
    satisfied: list[str] = ["DifferentialBackend"]
    if isinstance(instance, NullCalibratedBackend):
        satisfied.append("NullCalibratedBackend")
    if isinstance(instance, ScalableDifferentialBackend):
        satisfied.append("ScalableDifferentialBackend")
    instance.__sway_protocols__ = tuple(satisfied)  # type: ignore[attr-defined]
    return instance
| 158 | |
| 159 | |
# Imported at the bottom of the module instead of with the other imports
# (hence the E402 suppression). NOTE(review): presumably this avoids a
# circular import, since build_two_separate calls build() once per side —
# confirm before moving this line to the top of the file.
from dlm_sway.backends.two_model import TwoModelDifferential, build_two_separate  # noqa: E402

# Names exported by a star-import: the dispatcher plus the re-exported
# two-model helpers.
__all__ = ["TwoModelDifferential", "build", "build_two_separate"]