| 1 |
"""Shared Modelfile directive builders. |
| 2 |
|
| 3 |
Used by both `modelfile.py` (text-only path) and `vl_modelfile.py` |
| 4 |
(vision-language path). Hoisted out so the VL renderer doesn't reach |
| 5 |
across `_`-prefixed names into its sibling module and no longer needs |
| 6 |
a ModelfileContext adapter bridge. |
| 7 |
|
| 8 |
All public helpers here take the minimal field set they read. Both |
| 9 |
call sites pack their own context shape and hand field values in — |
| 10 |
the renderers keep their own frozen dataclasses as the authoritative |
| 11 |
input type; the helpers stay input-shape-agnostic. |
| 12 |
|
| 13 |
Security note: `build_system_line` + `build_license_line` pass their |
| 14 |
strings through `json.dumps`. Ollama's Modelfile grammar accepts JSON |
| 15 |
string-literal escapes verbatim (`\\"`, `\\n`, `\\\\`), so a hostile |
| 16 |
prompt surfaces as content rather than metaparse. See `modelfile.py` |
| 17 |
module docstring for the full rationale. |
| 18 |
""" |
| 19 |
|
| 20 |
from __future__ import annotations |
| 21 |
|
| 22 |
import json |
| 23 |
from datetime import UTC, datetime |
| 24 |
from pathlib import Path |
| 25 |
from typing import TYPE_CHECKING |
| 26 |
|
| 27 |
from dlm.export.ollama.errors import ModelfileError |
| 28 |
|
| 29 |
if TYPE_CHECKING: |
| 30 |
from dlm.base_models import BaseModelSpec |
| 31 |
from dlm.export.ollama.template_registry import DialectTemplate |
| 32 |
|
| 33 |
|
| 34 |
def build_header( |
| 35 |
*, |
| 36 |
dlm_version: str, |
| 37 |
dlm_id: str, |
| 38 |
adapter_version: int, |
| 39 |
base_key: str, |
| 40 |
base_revision: str, |
| 41 |
quant: str, |
| 42 |
merged: bool, |
| 43 |
source_dlm_path: Path | None = None, |
| 44 |
) -> str: |
| 45 |
"""Top-of-file `# Generated by dlm …` comment block.""" |
| 46 |
now = datetime.now(UTC).replace(tzinfo=None, microsecond=0).isoformat() |
| 47 |
lines = [f"# Generated by dlm {dlm_version} on {now}"] |
| 48 |
if source_dlm_path is not None: |
| 49 |
lines.append(f"# Source: {source_dlm_path}") |
| 50 |
lines.extend( |
| 51 |
[ |
| 52 |
f"# dlm_id: {dlm_id}", |
| 53 |
f"# adapter_version: {adapter_version}", |
| 54 |
f"# base_model: {base_key} (revision {base_revision})", |
| 55 |
f"# quant: {quant}", |
| 56 |
f"# merged: {merged}", |
| 57 |
] |
| 58 |
) |
| 59 |
return "\n".join(lines) |
| 60 |
|
| 61 |
|
| 62 |
def resolve_stops(adapter_dir: Path, template_row: DialectTemplate) -> tuple[str, ...]:
    """Union of dialect defaults + EOS/added-tokens from the adapter tokenizer.

    Per the tokenizer contract: added special tokens from the adapter
    tokenizer become additional stops. Without this, a pad-token-grown
    model emits `<|pad|>` indefinitely.

    Order is preserved: dialect defaults first, then extra hints, then
    adapter-derived stops; duplicates keep their first position.
    """
    # dict.fromkeys is an order-preserving dedup — it replaces the
    # original membership-check loop *and* the second dedup pass in one go.
    ordered: dict[str, None] = dict.fromkeys(template_row.default_stops)
    for hint in template_row.extra_stop_hints:
        ordered.setdefault(hint)
    for tok in _read_adapter_stops(adapter_dir):
        if tok:  # drop empty strings coming from a degenerate tokenizer config
            ordered.setdefault(tok)
    return tuple(ordered)
| 84 |
|
| 85 |
|
| 86 |
def _read_adapter_stops(adapter_dir: Path) -> list[str]:
    """Collect `eos_token` plus special added-tokens from the adapter tokenizer config.

    Returns an empty list when `tokenizer_config.json` is absent; raises
    ModelfileError when the file exists but cannot be read or parsed.
    """
    cfg_path = adapter_dir / "tokenizer_config.json"
    if not cfg_path.exists():
        return []
    try:
        raw = cfg_path.read_text(encoding="utf-8")
        cfg = json.loads(raw)
    except (OSError, json.JSONDecodeError) as exc:
        raise ModelfileError(
            f"adapter tokenizer_config.json at {cfg_path} is unreadable: {exc}"
        ) from exc

    tokens: list[str] = []

    # `eos_token` may be a bare string or an AddedToken-style dict.
    eos = cfg.get("eos_token")
    if isinstance(eos, str):
        if eos:
            tokens.append(eos)
    elif isinstance(eos, dict):
        content = eos.get("content")
        if isinstance(content, str):
            tokens.append(content)

    # Every `special: true` entry in added_tokens_decoder is a stop candidate.
    decoder = cfg.get("added_tokens_decoder") or {}
    if isinstance(decoder, dict):
        for meta in decoder.values():
            if not (isinstance(meta, dict) and meta.get("special") is True):
                continue
            text = meta.get("content")
            if isinstance(text, str):
                tokens.append(text)
    return tokens
| 113 |
|
| 114 |
|
| 115 |
def build_param_lines(
    *,
    stops: tuple[str, ...],
    temperature: float,
    top_p: float,
    num_ctx: int | None,
    draft_model: str | None,
) -> list[str]:
    """Emit the `PARAMETER stop …` + sampling defaults block.

    Stop strings go through `json.dumps` so escapes survive Ollama's
    Modelfile grammar (see module docstring).
    """
    out = [f"PARAMETER stop {json.dumps(stop)}" for stop in stops]
    out += [
        f"PARAMETER temperature {temperature}",
        f"PARAMETER top_p {top_p}",
    ]
    if num_ctx is not None:
        out.append(f"PARAMETER num_ctx {num_ctx}")
    if draft_model is not None:
        out += [
            f"# Speculative decoding: `ollama pull {draft_model}` first.",
            f"PARAMETER draft_model {draft_model}",
        ]
    return out
| 135 |
|
| 136 |
|
| 137 |
def resolve_num_ctx(
    training_sequence_len: int | None,
    spec_context_length: int,
) -> int | None:
    """Cap `training_sequence_len` at the base spec's `context_length`.

    Returns `None` when the document didn't pin a length — Ollama's
    2048 default applies. Otherwise the capped length so a document
    trained at 8192 gets the window it expects without exceeding the
    base model's positional-embedding table.
    """
    if training_sequence_len is None:
        return None
    if training_sequence_len > spec_context_length:
        return spec_context_length
    return training_sequence_len
| 151 |
|
| 152 |
|
| 153 |
def build_system_line(system_prompt: str | None) -> str | None:
    """JSON-escaped `SYSTEM "…"` directive, or `None` when no prompt.

    Whitespace-only prompts count as "no prompt"; the kept text is
    stripped before escaping.
    """
    if system_prompt is None:
        return None
    text = system_prompt.strip()
    return f"SYSTEM {json.dumps(text)}" if text else None
| 161 |
|
| 162 |
|
| 163 |
def build_license_line(spec: BaseModelSpec) -> str | None:
    """`LICENSE "…"` directive from the base model's SPDX id, or `None` when unset."""
    spdx = spec.license_spdx
    return f"LICENSE {json.dumps(spdx)}" if spdx else None