Python · 4820 bytes Raw Blame History
1 """The :class:`Model` abstraction and :class:`ModelSpec` user-facing config.
2
3 Probes operate on objects that satisfy :class:`Model` (for generation)
4 and :class:`~dlm_sway.core.scoring.ScoringBackend` (for logit-level
5 access). Backends return concrete instances of both — they are
6 deliberately separate Protocols because not every backend exposes logits
7 (e.g. an Ollama HTTP backend would implement ``Model`` but not
8 ``ScoringBackend``).
9
10 The user-facing surface is :class:`ModelSpec`, a pydantic model that
11 describes how to materialize a base + adapter pair. No ``.dlm``
12 concepts live at this layer — those belong in
13 :mod:`dlm_sway.integrations.dlm`.
14 """
15
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    model_validator,
)
23
# Closed vocabulary for ModelSpec.kind: pydantic validates against this
# Literal, so an unknown kind in a spec fails at load time rather than
# at backend-dispatch time.
BackendKind = Literal["hf", "mlx", "api", "dummy", "custom"]
"""Registered scoring-backend kinds.

``api`` targets an OpenAI-compatible HTTP endpoint (OpenAI, vLLM
serve, Ollama). ``custom`` is an escape hatch — the runner looks up
an entry point when it sees ``custom`` in a spec.
"""
31
32
class ModelSpec(BaseModel):
    """How to materialize one model (base or fine-tuned).

    Cross-field requirements (``entry_point`` for ``kind='custom'``,
    ``endpoint`` for ``kind='api'``) are enforced at validation time so
    a malformed spec fails fast with a ``ValidationError`` at the spec
    boundary instead of surfacing later inside a backend.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    kind: BackendKind = "hf"
    base: str
    """HuggingFace repo id (``HuggingFaceTB/SmolLM2-135M-Instruct``) or
    a local path to a model directory."""

    adapter: Path | None = None
    """Path to a PEFT adapter directory (containing ``adapter_config.json``
    and ``adapter_model.safetensors``). ``None`` → base-only model."""

    dtype: Literal["auto", "fp16", "bf16", "fp32"] = "auto"
    device: str = "auto"
    """``"auto"`` chooses CUDA → MPS → CPU in that order."""

    trust_remote_code: bool = False
    """HuggingFace ``trust_remote_code`` passthrough. Off by default —
    the user must opt in explicitly, matching sway's no-surprises
    posture."""

    entry_point: str | None = Field(default=None)
    """Required when ``kind='custom'``. Import path like
    ``mypkg.mybackend:MyBackend``."""

    endpoint: str | None = Field(default=None)
    """Required when ``kind='api'``. Base URL of an OpenAI-compatible
    completions server (``https://api.openai.com``,
    ``http://localhost:11434`` for Ollama, or wherever ``vllm serve``
    listens). The ``/v1/completions`` path is appended by the backend.

    The API key comes from the environment (``SWAY_API_KEY``, falling
    back to ``OPENAI_API_KEY``) so secrets don't live in the YAML spec."""

    @field_validator("adapter")
    @classmethod
    def _normalize_adapter_path(cls, v: Path | None) -> Path | None:
        """Expand ``~`` and resolve relative segments at spec-load time.

        Before B22 every backend re-did this work in its constructor;
        normalizing once at the spec boundary means the cache key in
        :func:`dlm_sway.probes._null_cache.compute_key` (which encodes
        the adapter path) is stable regardless of how the user spelled
        it in YAML or on the CLI.
        """
        if v is None:
            return None
        return Path(v).expanduser().resolve()

    @model_validator(mode="after")
    def _require_kind_specific_fields(self) -> ModelSpec:
        """Enforce the per-kind required fields the docstrings promise.

        ``entry_point`` was documented as required for ``kind='custom'``
        and ``endpoint`` as required for ``kind='api'``, but neither was
        checked — a bad spec only failed later, deep inside the runner.
        Raising here converts the mistake into a ``ValidationError`` at
        spec-load time.
        """
        if self.kind == "custom" and not self.entry_point:
            raise ValueError(
                "kind='custom' requires 'entry_point' "
                "(import path like 'mypkg.mybackend:MyBackend')"
            )
        if self.kind == "api" and not self.endpoint:
            raise ValueError(
                "kind='api' requires 'endpoint' "
                "(base URL of an OpenAI-compatible server)"
            )
        return self
83
84
85 @dataclass(frozen=True, slots=True)
86 class LoadedModel:
87 """A materialized model plus the tokenizer that produced it.
88
89 Returned by backend ``load()`` methods. Probes usually don't touch
90 this directly — they go through the :class:`Model` /
91 :class:`~dlm_sway.core.scoring.ScoringBackend` Protocols.
92 """
93
94 id: str
95 """Stable handle: ``"base"`` or ``"ft"`` typically."""
96 spec: ModelSpec
97 model: Any
98 """Framework-native handle (torch ``nn.Module``, MLX array module …).
99
100 Typed as ``Any`` because the frameworks themselves ship unstubbed.
101 Backend implementations narrow this at their boundary."""
102 tokenizer: Any
103 meta: dict[str, Any]
104 """Backend-captured metadata: device, dtype, adapter version, bytes
105 on disk, num trainable params. Surfaced in the suite report."""
106
107
@runtime_checkable
class Model(Protocol):
    """Minimum interface for text generation.

    Implemented by backend-wrapped model objects. Probes that need logits
    also require :class:`~dlm_sway.core.scoring.ScoringBackend`.

    NOTE: ``@runtime_checkable`` makes ``isinstance(obj, Model)`` legal,
    but such checks only verify that the members *exist* — method
    signatures are never validated at runtime.
    """

    # Stable handle for this model (LoadedModel.id uses "base"/"ft";
    # presumably the same values flow through here — confirm at backends).
    id: str

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 1.0,
        seed: int = 0,
    ) -> str:
        """Generate a completion.

        Defaults (``temperature=0``, ``top_p=1``) are greedy-decode for
        reproducibility. Callers wanting sampled output must pass
        non-defaults *and* a seed.

        :param prompt: Text to continue; forwarded to the backend.
        :param max_new_tokens: Cap on generated tokens. Keyword-only and
            without a default — callers must budget explicitly.
        :param temperature: Sampling temperature; ``0.0`` selects greedy
            decoding.
        :param top_p: Nucleus-sampling cutoff; ``1.0`` disables it.
        :param seed: RNG seed, meaningful only for sampled decoding.
        :return: The completion text (whether the prompt is echoed back
            is backend-defined — not specified here; confirm per backend).
        """
        ...

    def close(self) -> None:
        """Release any resources held by this model (idempotency is not
        specified by this Protocol — check concrete backends)."""
        ...
137 ...