Python · 4923 bytes Raw Blame History
"""Two-model differential wrapper.

Used when ``sway.yaml`` sets ``defaults.differential: false``. Instead of
toggling one loaded model via PEFT, the runner loads two independent
backends and routes ``as_base()`` through one and ``as_finetuned()``
through the other.

This doubles memory (two full model copies in RAM) and halves throughput
— the flag exists for custom backends that can't do in-place adapter
toggling, not as a production setting. The differential toggle path
remains the recommended configuration.
"""
13
14 from __future__ import annotations
15
16 from contextlib import contextmanager
17 from typing import TYPE_CHECKING, Any
18
19 if TYPE_CHECKING:
20 from collections.abc import Iterator
21
22 from dlm_sway.core.scoring import DifferentialBackend
23
24
class TwoModelDifferential:
    """Wrap two independent backends as a single ``DifferentialBackend``.

    ``as_base()`` delegates to ``base.as_base()``; ``as_finetuned()``
    delegates to ``ft.as_finetuned()``. The two inner backends are held
    for their lifetime — callers close them explicitly (or via
    :meth:`close`) when finished.

    The wrapper does **not** enforce view exclusion across the two
    backends (each inner backend is responsible for its own exclusion,
    but nothing stops a caller holding simultaneous base+ft contexts
    that come from different backends). That's fine — the exclusion
    invariant in :class:`DifferentialBackend` exists to protect the
    toggle state of a single-model backend; two independent models
    have no toggle to corrupt.
    """

    def __init__(self, base: DifferentialBackend, ft: DifferentialBackend) -> None:
        """Store both backends and derive the composed concurrency flag.

        Args:
            base: Backend whose ``as_base()`` view serves :meth:`as_base`.
            ft: Backend whose ``as_finetuned()`` view serves
                :meth:`as_finetuned`.
        """
        self._base = base
        self._ft = ft
        # Compose the concurrency flag from both inner backends. The
        # wrapper itself holds no shared mutable state — the toggle
        # invariant :class:`DifferentialBackend` protects doesn't apply
        # here — so the composed value is exactly as strong as the
        # weaker of the two inner backends. A backend that doesn't
        # declare the attribute is treated as unsafe (False).
        self.safe_for_concurrent_views: bool = bool(
            getattr(self._base, "safe_for_concurrent_views", False)
            and getattr(self._ft, "safe_for_concurrent_views", False)
        )

    @contextmanager
    def as_base(self) -> Iterator[Any]:
        """Yield the view produced by the base-side backend's ``as_base()``."""
        with self._base.as_base() as view:
            yield view

    @contextmanager
    def as_finetuned(self) -> Iterator[Any]:
        """Yield the view produced by the ft-side backend's ``as_finetuned()``."""
        with self._ft.as_finetuned() as view:
            yield view

    def close(self) -> None:
        """Close both inner backends if they expose a ``close()`` method.

        The ft backend is closed even if closing the base backend raises.
        Otherwise one failure would leave a whole model copy loaded. Any
        exception still propagates after both closes have been attempted.
        If the same object was passed as both ``base`` and ``ft``, it is
        closed only once.
        """
        base_close = getattr(self._base, "close", None)
        # Identity check: don't double-close a backend shared by both sides.
        ft_close = (
            None if self._ft is self._base else getattr(self._ft, "close", None)
        )
        try:
            if callable(base_close):
                base_close()
        finally:
            if callable(ft_close):
                ft_close()

    def preflight_finite_check(self) -> tuple[bool, str]:
        """Delegate preflight to the ft-side backend if it supports it.

        The failure mode the preflight catches is "adapter weights are
        NaN" — which lives on the ft side only. A base-only check
        would validate the wrong thing.

        Returns:
            The ft backend's ``(ok, reason)`` result coerced to
            ``(bool, str)``. Returns
            ``(True, "ft backend does not support preflight")`` when the ft
            backend has no callable ``preflight_finite_check``.
        """
        fn = getattr(self._ft, "preflight_finite_check", None)
        if not callable(fn):
            return True, "ft backend does not support preflight"
        ok, reason = fn()
        return bool(ok), str(reason)
84
85
def _build_via_backends(spec_models: Any) -> tuple[DifferentialBackend, DifferentialBackend]:
    """Materialize ``(base_backend, ft_backend)`` from a ``SuiteModels`` spec.

    Uses :func:`dlm_sway.backends.build` for each side: base first, then ft.

    Args:
        spec_models: A ``SuiteModels``-like object exposing ``.base`` and
            ``.ft`` model specs.

    Returns:
        The built ``(base_backend, ft_backend)`` pair.

    Raises:
        SpecValidationError: If the base spec is an HF model with no adapter
            path. The HF backend loads via PEFT, so the error points the user
            at ``differential: true``. This check runs before any model is
            loaded.
        Exception: Whatever ``build`` raises for either side. If the ft-side
            build fails, the already-built base backend is closed first (when
            it exposes ``close()``), so a failed setup doesn't leave a full
            model copy in memory.
    """
    from dlm_sway.backends import build
    from dlm_sway.core.errors import SpecValidationError

    base_spec = spec_models.base
    ft_spec = spec_models.ft

    if base_spec.kind == "hf" and base_spec.adapter is None:
        raise SpecValidationError(
            "defaults.differential=false with an HF base requires the base "
            "ModelSpec to carry an adapter path too (the HF backend loads "
            "via PEFT). Either set `models.base.adapter` explicitly, or "
            "switch to `defaults.differential: true` to use the single-load "
            "toggle path."
        )

    base_backend = build(base_spec)
    try:
        ft_backend = build(ft_spec)
    except BaseException:
        # The base model is already loaded; release it before propagating.
        close = getattr(base_backend, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Best-effort cleanup: surface the original build error, not
                # a secondary close failure.
                pass
        raise
    return base_backend, ft_backend
112
113
def build_two_separate(spec_models: Any) -> TwoModelDifferential:
    """Construct a :class:`TwoModelDifferential` for a ``SuiteModels`` spec.

    This is the entry point the suite runner uses when
    ``spec.defaults.differential`` is ``False``. It is re-exported from
    :mod:`dlm_sway.backends` to sit alongside :func:`backends.build`.
    """
    pair = _build_via_backends(spec_models)
    base_side, ft_side = pair
    return TwoModelDifferential(ft=ft_side, base=base_side)
123
124
# Public API of this module.
__all__ = [
    "TwoModelDifferential",
    "build_two_separate",
]