1 """Shared Modelfile directive builders.
2
3 Used by both `modelfile.py` (text-only path) and `vl_modelfile.py`
4 (vision-language path). Hoisted out so the VL renderer doesn't reach
5 across `_`-prefixed names into its sibling module and no longer needs
6 a ModelfileContext adapter bridge.
7
8 All public helpers here take the minimal field set they read. Both
9 call sites pack their own context shape and hand field values in —
10 the renderers keep their own frozen dataclasses as the authoritative
11 input type; the helpers stay input-shape-agnostic.
12
13 Security note: `build_system_line` + `build_license_line` pass their
14 strings through `json.dumps`. Ollama's Modelfile grammar accepts JSON
15 string-literal escapes verbatim (`\\"`, `\\n`, `\\\\`), so a hostile
16 prompt surfaces as content rather than metaparse. See `modelfile.py`
17 module docstring for the full rationale.
18 """
19
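# Illustrative only (hostile string assumed), showing the property the
# security note relies on: `json.dumps` escapes quotes and newlines, so an
# injected directive stays inside a single string literal.
#
#     >>> import json
#     >>> json.dumps('end" SYSTEM "pwned')
#     '"end\\" SYSTEM \\"pwned"'
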
from __future__ import annotations

import json
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING

from dlm.export.ollama.errors import ModelfileError

if TYPE_CHECKING:
    from dlm.base_models import BaseModelSpec
    from dlm.export.ollama.template_registry import DialectTemplate


def build_header(
    *,
    dlm_version: str,
    dlm_id: str,
    adapter_version: int,
    base_key: str,
    base_revision: str,
    quant: str,
    merged: bool,
    source_dlm_path: Path | None = None,
) -> str:
    """Top-of-file `# Generated by dlm …` comment block."""
    now = datetime.now(UTC).replace(tzinfo=None, microsecond=0).isoformat()
    lines = [f"# Generated by dlm {dlm_version} on {now}"]
    if source_dlm_path is not None:
        lines.append(f"# Source: {source_dlm_path}")
    lines.extend(
        [
            f"# dlm_id: {dlm_id}",
            f"# adapter_version: {adapter_version}",
            f"# base_model: {base_key} (revision {base_revision})",
            f"# quant: {quant}",
            f"# merged: {merged}",
        ]
    )
    return "\n".join(lines)

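# Sketch of a rendered header (all values illustrative; the `# Source:` line
# appears only when `source_dlm_path` is given):
#
#     # Generated by dlm 0.4.1 on 2025-01-01T00:00:00
#     # dlm_id: example-dlm-id
#     # adapter_version: 3
#     # base_model: llama3-8b (revision main)
#     # quant: q4_K_M
#     # merged: False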

def resolve_stops(adapter_dir: Path, template_row: DialectTemplate) -> tuple[str, ...]:
    """Union of dialect defaults + EOS/added-tokens from the adapter tokenizer.

    Per the tokenizer contract: added special tokens from the adapter
    tokenizer become additional stops. Without this, a pad-token-grown
    model emits `<|pad|>` indefinitely.
    """
    merged: list[str] = list(template_row.default_stops)
    merged.extend(template_row.extra_stop_hints)

    adapter_stops = _read_adapter_stops(adapter_dir)
    for tok in adapter_stops:
        if tok and tok not in merged:
            merged.append(tok)

    seen: set[str] = set()
    unique: list[str] = []
    for tok in merged:
        if tok not in seen:
            seen.add(tok)
            unique.append(tok)
    return tuple(unique)

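# Illustrative only (token names assumed): with dialect defaults ("</s>",),
# no extra hints, and an adapter tokenizer that registers <|pad|> as a
# special token, this resolves to ("</s>", "<|pad|>"); order is preserved
# and duplicates are dropped.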

def _read_adapter_stops(adapter_dir: Path) -> list[str]:
    """Pull `eos_token` + added-tokens from the adapter tokenizer config."""
    cfg_path = adapter_dir / "tokenizer_config.json"
    if not cfg_path.exists():
        return []
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError) as exc:
        raise ModelfileError(
            f"adapter tokenizer_config.json at {cfg_path} is unreadable: {exc}"
        ) from exc

    stops: list[str] = []
    eos = cfg.get("eos_token")
    if isinstance(eos, str) and eos:
        stops.append(eos)
    elif isinstance(eos, dict) and isinstance(eos.get("content"), str):
        stops.append(eos["content"])

    added = cfg.get("added_tokens_decoder") or {}
    if isinstance(added, dict):
        for entry in added.values():
            if isinstance(entry, dict) and entry.get("special") is True:
                content = entry.get("content")
                if isinstance(content, str):
                    stops.append(content)
    return stops

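# Shape read by the helper above (values illustrative): `eos_token` may be a
# bare string or an AddedToken-style dict, and only `added_tokens_decoder`
# entries flagged "special" contribute stops.
#
#     {
#       "eos_token": {"content": "</s>"},
#       "added_tokens_decoder": {
#         "32000": {"content": "<|pad|>", "special": true}
#       }
#     }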

def build_param_lines(
    *,
    stops: tuple[str, ...],
    temperature: float,
    top_p: float,
    num_ctx: int | None,
    draft_model: str | None,
) -> list[str]:
    """Emit the `PARAMETER stop …` + sampling defaults block."""
    lines: list[str] = []
    for stop in stops:
        lines.append(f"PARAMETER stop {json.dumps(stop)}")
    lines.append(f"PARAMETER temperature {temperature}")
    lines.append(f"PARAMETER top_p {top_p}")
    if num_ctx is not None:
        lines.append(f"PARAMETER num_ctx {num_ctx}")
    if draft_model is not None:
        lines.append(f"# Speculative decoding: `ollama pull {draft_model}` first.")
        lines.append(f"PARAMETER draft_model {draft_model}")
    return lines

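# Illustrative call (sampling values assumed):
#
#     build_param_lines(stops=("</s>",), temperature=0.7, top_p=0.9,
#                       num_ctx=4096, draft_model=None)
#
# renders:
#
#     PARAMETER stop "</s>"
#     PARAMETER temperature 0.7
#     PARAMETER top_p 0.9
#     PARAMETER num_ctx 4096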

def resolve_num_ctx(
    training_sequence_len: int | None,
    spec_context_length: int,
) -> int | None:
    """Cap `training_sequence_len` at the base spec's `context_length`.

    Returns `None` when the document didn't pin a length — Ollama's
    2048 default applies. Otherwise returns the capped length, so a
    document trained at 8192 gets the window it expects without
    exceeding the base model's positional-embedding table.
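
    Illustrative:

    >>> resolve_num_ctx(8192, 4096)
    4096
    >>> resolve_num_ctx(None, 4096) is None
    True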
147 """
148 if training_sequence_len is None:
149 return None
150 return min(training_sequence_len, spec_context_length)
151
152
153 def build_system_line(system_prompt: str | None) -> str | None:
154 """JSON-escaped `SYSTEM "…"` directive, or `None` when no prompt."""
155 if system_prompt is None:
156 return None
157 stripped = system_prompt.strip()
158 if not stripped:
159 return None
160 return f"SYSTEM {json.dumps(stripped)}"
161
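# Illustrative (prompt text assumed): whitespace-only prompts yield no
# directive; anything else is JSON-quoted verbatim.
#
#     build_system_line("You are terse.")  ->  'SYSTEM "You are terse."'
#     build_system_line("   ")             ->  None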

def build_license_line(spec: BaseModelSpec) -> str | None:
    """`LICENSE "…"` directive from the base model's SPDX id."""
    if not spec.license_spdx:
        return None
    return f"LICENSE {json.dumps(spec.license_spdx)}"