| 1 | """The :class:`Model` abstraction and :class:`ModelSpec` user-facing config. |
| 2 | |
| 3 | Probes operate on objects that satisfy :class:`Model` (for generation) |
| 4 | and :class:`~dlm_sway.core.scoring.ScoringBackend` (for logit-level |
| 5 | access). Backends return concrete instances of both — they are |
| 6 | deliberately separate Protocols because not every backend exposes logits |
| 7 | (e.g. an Ollama HTTP backend would implement ``Model`` but not |
| 8 | ``ScoringBackend``). |
| 9 | |
| 10 | The user-facing surface is :class:`ModelSpec`, a pydantic model that |
| 11 | describes how to materialize a base + adapter pair. No ``.dlm`` |
| 12 | concepts live at this layer — those belong in |
| 13 | :mod:`dlm_sway.integrations.dlm`. |
| 14 | """ |
| 15 | |
| 16 | from __future__ import annotations |
| 17 | |
| 18 | from dataclasses import dataclass |
| 19 | from pathlib import Path |
| 20 | from typing import Any, Literal, Protocol, runtime_checkable |
| 21 | |
| 22 | from pydantic import BaseModel, ConfigDict, Field, field_validator |
| 23 | |
| 24 | BackendKind = Literal["hf", "mlx", "api", "dummy", "custom"] |
| 25 | """Registered scoring-backend kinds. |
| 26 | |
| 27 | ``api`` targets an OpenAI-compatible HTTP endpoint (OpenAI, vLLM |
| 28 | serve, Ollama). ``custom`` is an escape hatch — the runner looks up |
| 29 | an entry point when it sees ``custom`` in a spec. |
| 30 | """ |
| 31 | |
| 32 | |
| 33 | class ModelSpec(BaseModel): |
| 34 | """How to materialize one model (base or fine-tuned).""" |

    model_config = ConfigDict(extra="forbid", frozen=True)

    kind: BackendKind = "hf"
    base: str
    """HuggingFace repo id (``HuggingFaceTB/SmolLM2-135M-Instruct``) or
    a local path to a model directory."""

    adapter: Path | None = None
    """Path to a PEFT adapter directory (containing ``adapter_config.json``
    and ``adapter_model.safetensors``). ``None`` → base-only model."""

    dtype: Literal["auto", "fp16", "bf16", "fp32"] = "auto"
    device: str = "auto"
    """``"auto"`` chooses CUDA → MPS → CPU in that order."""

    trust_remote_code: bool = False
    """HuggingFace ``trust_remote_code`` passthrough. Off by default —
    the user must opt in explicitly, matching sway's no-surprises
    posture."""

    entry_point: str | None = Field(default=None)
    """Required when ``kind='custom'``. Import path like
    ``mypkg.mybackend:MyBackend``.
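
    One plausible resolution sketch (the real lookup lives in the
    runner, not in this module; ``resolve_entry_point`` is a
    hypothetical helper)::

        import importlib

        def resolve_entry_point(path: str) -> type:
            module_name, _, attr = path.partition(":")
            return getattr(importlib.import_module(module_name), attr)
    """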

    endpoint: str | None = Field(default=None)
    """Required when ``kind='api'``. Base URL of an OpenAI-compatible
    completions server (``https://api.openai.com``,
    ``http://localhost:11434`` for Ollama, or wherever ``vllm serve``
    listens). The ``/v1/completions`` path is appended by the backend.

    The API key comes from the environment (``SWAY_API_KEY``, falling
    back to ``OPENAI_API_KEY``) so secrets don't live in the YAML spec.
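
    The fallback order, as a sketch (the real lookup belongs to the
    ``api`` backend, not this module)::

        import os

        api_key = os.environ.get("SWAY_API_KEY") or os.environ.get(
            "OPENAI_API_KEY"
        )
    """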

    @field_validator("adapter")
    @classmethod
    def _normalize_adapter_path(cls, v: Path | None) -> Path | None:
        """Expand ``~`` and resolve relative segments at spec-load time.

        Before B22 every backend re-did this work in its constructor;
        normalizing once at the spec boundary means the cache key in
        :func:`dlm_sway.probes._null_cache.compute_key` (which encodes
        the adapter path) is stable regardless of how the user spelled
        it in YAML or on the CLI.
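
        An illustrative equivalence (paths are hypothetical, assuming
        ``HOME=/home/u``)::

            ModelSpec(base="m", adapter="~/ft/adapter").adapter
            # -> Path("/home/u/ft/adapter"), same cache key either way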
| 79 | """ |
| 80 | if v is None: |
| 81 | return None |
| 82 | return Path(v).expanduser().resolve() |
| 83 | |
| 84 | |
| 85 | @dataclass(frozen=True, slots=True) |
| 86 | class LoadedModel: |
| 87 | """A materialized model plus the tokenizer that produced it. |

    Returned by backend ``load()`` methods. Probes usually don't touch
    this directly — they go through the :class:`Model` /
    :class:`~dlm_sway.core.scoring.ScoringBackend` Protocols.
    """

    id: str
    """Stable handle, typically ``"base"`` or ``"ft"``."""
    spec: ModelSpec
    model: Any
    """Framework-native handle (torch ``nn.Module``, MLX ``nn.Module``, …).

    Typed as ``Any`` because the frameworks themselves ship unstubbed.
    Backend implementations narrow this at their boundary."""
    tokenizer: Any
    meta: dict[str, Any]
    """Backend-captured metadata: device, dtype, adapter version, bytes
    on disk, number of trainable params. Surfaced in the suite report.
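
    Illustrative shape (keys are examples, not a fixed schema)::

        {"device": "cuda:0", "dtype": "bf16", "trainable_params": 4_718_592}
    """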


@runtime_checkable
class Model(Protocol):
    """Minimum interface for text generation.

    Implemented by backend-wrapped model objects. Probes that need logits
    also require :class:`~dlm_sway.core.scoring.ScoringBackend`.
    """

    id: str

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 1.0,
        seed: int = 0,
    ) -> str:
        """Generate a completion.

        The defaults (``temperature=0``, ``top_p=1``) give greedy
        decoding for reproducibility. Callers who want sampled output
        must pass a non-zero temperature *and* an explicit seed.
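
        A sampled call under those rules (values are illustrative)::

            text = model.generate(
                "Hello",
                max_new_tokens=32,
                temperature=0.8,
                top_p=0.95,
                seed=7,
            )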
| 132 | """ |
| 133 | ... |
| 134 | |
| 135 | def close(self) -> None: |
| 136 | """Release any resources held by this model.""" |
| 137 | ... |