Python · 6752 bytes Raw Blame History
"""Dialect → Go `text/template` registry.

One `.gotmpl` file per chat-template dialect, shipped beside this
module. The registry also records per-dialect:

- **Default stops** — token strings that Ollama must refuse to
  generate past. Missing stops cause runaway generation (findings §9).
- **Default params** — `temperature` + `top_p` tuned per family.

The templates themselves are Go `text/template` syntax: `{{.Role}}`,
`{{range .Messages}}`, `{{if .System}}`, etc. Ollama renders them
at inference time. The round-trip test (`tests/unit/export/ollama/
test_template_registry_roundtrip.py`) verifies each template
produces token-identical output to the base model's Jinja reference
on a fixed message-set matrix.
"""
17
18 from __future__ import annotations
19
20 from dataclasses import dataclass, field
21 from pathlib import Path
22 from typing import Final, Literal
23
24 from dlm.export.ollama.errors import TemplateRegistryError
25
# Closed set of dialect names accepted by the registry. Each name is
# also used as a key of the registry mapping defined in this module.
Dialect = Literal[
    "chatml",
    "qwen3thinking",
    "gemma2",
    "smollm3",
    "olmo2",
    "llama3",
    "phi3",
    "phi4mini",
    "mistral",
]

# The `.gotmpl` files live in a `templates/` directory next to this module.
_TEMPLATES_DIR: Final[Path] = Path(__file__).resolve().parent.joinpath("templates")
39
40
@dataclass(frozen=True)
class DialectTemplate:
    """Immutable registry entry: one dialect, its Go template, and defaults.

    Alongside the path to the `.gotmpl` file, an entry carries the stop
    strings and the sampling parameters used by default for the dialect.
    """

    # Name of the dialect this entry describes.
    dialect: Dialect
    # Location of the dialect's `.gotmpl` file.
    template_path: Path
    # Stop strings applied by default for this dialect.
    default_stops: tuple[str, ...]
    # Sampling defaults; individual entries may override them.
    default_temperature: float = 0.7
    default_top_p: float = 0.9
    # Special tokens added by training belong here ONLY when they are
    # inherent to the dialect (e.g. chatml's `<|im_end|>`). Tokens added
    # per adapter by the pad fallback are taken from the adapter
    # tokenizer at render time instead.
    extra_stop_hints: tuple[str, ...] = field(default_factory=tuple)

    def read_template(self) -> str:
        """Return the text of the `.gotmpl` file, decoded as UTF-8.

        Raises:
            TemplateRegistryError: `template_path` is not an existing file.
        """
        path = self.template_path
        if path.is_file():
            return path.read_text(encoding="utf-8")
        raise TemplateRegistryError(
            f"template file missing for {self.dialect!r}: {self.template_path}"
        )
63
64
def _make_row(
    dialect: Dialect,
    *stops: str,
    temperature: float = DialectTemplate.default_temperature,
    top_p: float = DialectTemplate.default_top_p,
) -> DialectTemplate:
    """Build one registry row whose template file is `<dialect>.gotmpl`.

    `stops` become the row's default stop strings, in the order given.
    `temperature` / `top_p` fall back to the dataclass defaults.
    """
    return DialectTemplate(
        dialect=dialect,
        template_path=_TEMPLATES_DIR / f"{dialect}.gotmpl",
        default_stops=stops,
        default_temperature=temperature,
        default_top_p=top_p,
    )


# Dialect name → registry row. Keys are taken from each row's own
# `dialect`, so a key can never disagree with the row it maps to.
_REGISTRY: Final[dict[Dialect, DialectTemplate]] = {
    row.dialect: row
    for row in (
        # Qwen 2.5 variants emit role-delimiter tokens such as
        # `<|im_start|>` at the top of each turn. Stopping on them
        # prevents runaway prompt-continuation when the model tries to
        # synthesize a new turn instead of yielding.
        _make_row("chatml", "<|im_end|>", "<|endoftext|>", "<|im_start|>"),
        # Qwen3's reasoning profile still uses ChatML turn framing, but
        # the upstream defaults run slightly broader sampling than the
        # legacy ChatML family.
        _make_row(
            "qwen3thinking",
            "<|im_end|>",
            "<|endoftext|>",
            "<|im_start|>",
            temperature=0.6,
            top_p=0.95,
        ),
        # Gemma 2 instruct frames turns with `<start_of_turn>` /
        # `<end_of_turn>` and ends the prompt with a trailing
        # `<start_of_turn>model`. Stop on both turn-boundary tokens and
        # `<eos>` so Ollama yields instead of continuing into a fresh
        # role block.
        _make_row("gemma2", "<end_of_turn>", "<eos>", "<start_of_turn>"),
        # SmolLM3 keeps ChatML turn framing but ships a reasoning-first
        # instruct prompt. Stop on the turn delimiters so Ollama yields
        # instead of hallucinating another role block.
        _make_row(
            "smollm3",
            "<|im_end|>",
            "<|end_of_text|>",
            "<|im_start|>",
            temperature=0.6,
            top_p=0.95,
        ),
        # OLMo2 uses `<|user|>` / `<|assistant|>` role markers and
        # closes assistant turns with `<|endoftext|>`.
        _make_row("olmo2", "<|endoftext|>", "<|user|>", "<|assistant|>", "<|system|>"),
        # Llama-3 instruct begins a role block with
        # `<|start_header_id|>`; it is listed so the model can't
        # generate a spurious new turn. `<|eom_id|>` is Llama-3.1's
        # tool-call terminator.
        _make_row(
            "llama3",
            "<|eot_id|>",
            "<|end_of_text|>",
            "<|start_header_id|>",
            "<|eom_id|>",
        ),
        # Phi-3.5-mini's chat template uses `<|user|>`, `<|assistant|>`,
        # `<|system|>` role delimiters — the audit called these out.
        _make_row(
            "phi3",
            "<|end|>",
            "<|endoftext|>",
            "<|user|>",
            "<|assistant|>",
            "<|system|>",
        ),
        # Phi-4-mini-reasoning keeps Phi-style `<|role|>` markers and
        # `<|end|>` closers, but the upstream template always injects a
        # default system preamble before user turns.
        _make_row(
            "phi4mini",
            "<|end|>",
            "<|endoftext|>",
            "<|user|>",
            "<|assistant|>",
            "<|system|>",
        ),
        # Mistral's instruct format wraps user turns in
        # `[INST] ... [/INST]`; `[INST]` as a stop prevents the model
        # from restarting a turn.
        _make_row("mistral", "</s>", "[/INST]", "[INST]"),
    )
}
159
160
def registered_dialects() -> tuple[Dialect, ...]:
    """Return every shipped dialect name, in registry order.

    Handy for CLI help text and for test parametrization.
    """
    # Iterating a dict yields its keys in insertion order.
    return tuple(_REGISTRY)
164
165
def get_template(dialect: str) -> DialectTemplate:
    """Look up the registry row for `dialect`.

    The parameter is a plain `str` so callers holding a pydantic
    `Literal` field can pass it straight through; validation happens
    here, against the registry keys.

    Raises:
        TemplateRegistryError: `dialect` is not a registered dialect.
    """
    if dialect in _REGISTRY:
        return _REGISTRY[dialect]

    known = sorted(_REGISTRY)
    raise TemplateRegistryError(
        f"unknown template dialect {dialect!r}; registered: {known}"
    )
178
179
def load_template_text(dialect: str) -> str:
    """Return the template text for `dialect`.

    Convenience wrapper: resolves the registry row with `get_template`
    and then reads its `.gotmpl` file via `read_template`.
    """
    row = get_template(dialect)
    return row.read_template()