Python · 7138 bytes Raw Blame History
1 #!/usr/bin/env python3
2 """Regenerate Sprint 12.6's per-dialect chat-template token-count goldens.
3
4 For every registered Go-template dialect that has a representative
5 base spec in the registry, this script:
6
7 1. Loads the HF tokenizer for the representative base.
8 2. Walks the shared scenario matrix (`tests/golden/chat-templates/
9 scenarios.json`).
10 3. For each scenario, renders via `apply_chat_template(...,
11 add_generation_prompt=True, tokenize=True)` and records the
12 token count.
13 4. Writes `tests/golden/chat-templates/<dialect>/<scenario>.json`.
14
15 The emitted files are byte-identical across runs against the same
16 pinned base revision + transformers version — they power the
17 Sprint 12.6 closed-loop check (Ollama Go template's
18 `prompt_eval_count` must equal these HF counts).
19
20 Usage:
21 uv run python scripts/refresh-chat-template-goldens.py
22 uv run python scripts/refresh-chat-template-goldens.py --check
23 uv run python scripts/refresh-chat-template-goldens.py --dialect chatml
24
25 Requires a hot HF cache for each dialect's representative base —
26 `--dialect NAME` lets you refresh one at a time if others aren't cached.
27
28 `--check` exits 0 when every existing golden matches the freshly-
29 computed value; exits 1 + prints a diff otherwise. Used by the
30 weekly drift workflow and by operators validating a base bump.
31 """
32
33 from __future__ import annotations
34
35 import argparse
36 import json
37 import sys
38 from datetime import UTC, datetime
39 from pathlib import Path
40 from typing import Any
41
42 _REPO_ROOT = Path(__file__).resolve().parents[1]
43 _GOLDENS_ROOT = _REPO_ROOT / "tests" / "golden" / "chat-templates"
44 _SCENARIOS_PATH = _GOLDENS_ROOT / "scenarios.json"
45 _DIALECT_SPECS_PATH = _GOLDENS_ROOT / "dialect-specs.json"
46
47
def _load_scenarios() -> list[dict[str, Any]]:
    """Parse and return the shared scenario matrix from `scenarios.json`.

    The file is an object with a top-level `"scenarios"` list; each entry
    carries at least a `name` and a `messages` list (see callers).
    """
    payload: dict[str, Any] = json.loads(_SCENARIOS_PATH.read_text(encoding="utf-8"))
    return payload["scenarios"]
52
53
def _load_dialect_specs() -> dict[str, str | None]:
    """Map each dialect name to its representative-base registry key.

    Keys starting with `_` are metadata/comment entries in the JSON file
    and are filtered out; a `None` value means the dialect has no
    representative base yet.
    """
    raw = json.loads(_DIALECT_SPECS_PATH.read_text(encoding="utf-8"))
    specs: dict[str, str | None] = {}
    for name, registry_key in raw.items():
        if name.startswith("_"):
            continue
        specs[name] = registry_key
    return specs
57
58
def _golden_path(dialect: str, scenario_name: str) -> Path:
    """Location of the golden file for one (dialect, scenario) pair."""
    filename = f"{scenario_name}.json"
    return _GOLDENS_ROOT.joinpath(dialect, filename)
61
62
63 def _compute_token_count(tokenizer: Any, messages: list[dict[str, str]]) -> int:
64 # `return_dict=False` makes HF return a plain `list[int]`; without it
65 # newer tokenizers hand back a `BatchEncoding` whose `len(...)` is
66 # the number of keys (2), not the number of tokens.
67 rendered = tokenizer.apply_chat_template(
68 messages,
69 add_generation_prompt=True,
70 tokenize=True,
71 return_dict=False,
72 )
73 return len(rendered)
74
75
def _load_tokenizer(registry_key: str) -> Any:
    """Instantiate the pinned HF tokenizer for the base named `registry_key`.

    Raises `KeyError` if the key is not present in the base-model registry.
    """
    from dlm.base_models import BASE_MODELS

    base_spec = BASE_MODELS[registry_key]

    from transformers import AutoTokenizer

    # `use_fast=True` is already the default; it is spelled out only for
    # clarity, since `apply_chat_template` behaves the same fast or slow.
    # Remote code execution is deliberately left disabled.
    return AutoTokenizer.from_pretrained(
        base_spec.hf_id,
        revision=base_spec.revision,
        use_fast=True,
        trust_remote_code=False,
    )
90
91
92 def _write_golden(
93 path: Path,
94 *,
95 dialect: str,
96 scenario: dict[str, Any],
97 registry_key: str,
98 token_count: int,
99 ) -> None:
100 path.parent.mkdir(parents=True, exist_ok=True)
101 blob: dict[str, Any] = {
102 "dialect": dialect,
103 "scenario": scenario["name"],
104 "representative_base": registry_key,
105 "messages": scenario["messages"],
106 "expected_hf_token_count": token_count,
107 "regenerated_at": datetime.now(UTC).replace(tzinfo=None, microsecond=0).isoformat(),
108 }
109 path.write_text(json.dumps(blob, indent=2) + "\n", encoding="utf-8")
110
111
112 def _read_recorded(path: Path) -> int | None:
113 if not path.is_file():
114 return None
115 try:
116 blob = json.loads(path.read_text(encoding="utf-8"))
117 except (OSError, json.JSONDecodeError):
118 return None
119 val = blob.get("expected_hf_token_count")
120 return val if isinstance(val, int) else None
121
122
def _refresh_dialect(
    dialect: str,
    registry_key: str | None,
    scenarios: list[dict[str, Any]],
    *,
    check: bool,
) -> tuple[int, int]:
    """Refresh (or, in `check` mode, verify) every scenario golden for one dialect.

    Returns `(written_or_matched, drifted)` counts so the caller can build
    an aggregate report.
    """
    if registry_key is None:
        # Dialect has no representative base pinned in the registry yet.
        print(f"[skip] {dialect}: no representative base in registry")
        return (0, 0)

    print(f"[load] {dialect}: using {registry_key}")
    tokenizer = _load_tokenizer(registry_key)

    ok_count = 0
    drift_count = 0
    for scenario in scenarios:
        name = scenario["name"]
        golden_file = _golden_path(dialect, name)
        fresh_count = _compute_token_count(tokenizer, scenario["messages"])
        disk_count = _read_recorded(golden_file)

        if not check:
            # Write mode: always (re)emit; the marker shows whether the
            # on-disk value actually changed ("+") or matched ("=").
            _write_golden(
                golden_file,
                dialect=dialect,
                scenario=scenario,
                registry_key=registry_key,
                token_count=fresh_count,
            )
            marker = "=" if disk_count == fresh_count else "+"
            print(f"  [{marker}] {name}: {fresh_count} tokens")
            ok_count += 1
            continue

        # Check mode: compare only, never write.
        if disk_count is None:
            print(f"  [MISS] {name}: no golden on disk")
            drift_count += 1
        elif disk_count != fresh_count:
            print(
                f"  [DRIFT] {name}: golden={disk_count} actual={fresh_count} "
                f"delta={fresh_count - disk_count:+d}"
            )
            drift_count += 1
        else:
            ok_count += 1

    return ok_count, drift_count
170
171
def main() -> int:
    """CLI entry point: parse flags, walk dialects, report, return exit code.

    Exit codes: 0 on success, 1 when `--check` finds drift, 2 for an
    unknown `--dialect` name.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--check",
        action="store_true",
        help="Exit non-zero on drift; don't write.",
    )
    parser.add_argument(
        "--dialect",
        help="Refresh only this dialect (default: all).",
    )
    args = parser.parse_args()

    scenarios = _load_scenarios()
    specs = _load_dialect_specs()

    # Narrow to a single dialect when requested, rejecting unknown names.
    if args.dialect is not None:
        if args.dialect not in specs:
            print(
                f"error: unknown dialect {args.dialect!r}; known: {sorted(specs)}",
                file=sys.stderr,
            )
            return 2
        specs = {args.dialect: specs[args.dialect]}

    grand_written = 0
    grand_drifted = 0
    for name, registry_key in specs.items():
        written, drifted = _refresh_dialect(name, registry_key, scenarios, check=args.check)
        grand_written += written
        grand_drifted += drifted

    if not args.check:
        print(f"\nOK: {grand_written} golden(s) written.")
        return 0

    if grand_drifted:
        print(
            f"\nFAIL: {grand_drifted} golden(s) drifted. Run without "
            "`--check` to regenerate, then review the diff.",
            file=sys.stderr,
        )
        return 1
    print(f"\nOK: {grand_written} golden(s) match current tokenizers.")
    return 0
215
216
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (0 = ok, 1 = drift under --check, 2 = unknown --dialect).
    sys.exit(main())