Python · 3002 bytes Raw Blame History
1 """Prompt templates for synthetic instruction generation."""
2
from __future__ import annotations

from typing import Final, Literal

from jinja2 import Environment, StrictUndefined, Template, TemplateSyntaxError, meta
from pydantic import BaseModel, ConfigDict, Field, model_validator
9
# Strategy names accepted by the synth pipeline. Only "extraction" and "expansion"
# have an entry in DEFAULT_PROMPT_TEMPLATES; "both" presumably means running the two
# together -- TODO confirm against the caller that expands it.
SynthStrategy = Literal["extraction", "expansion", "both"]
# Names of the supported reply formats; SynthPromptTemplate.output_parser defaults to
# "json_list". The parsers themselves are not defined in the code shown here.
PromptParserKind = Literal["json_list", "numbered_list"]
12
13
class SynthPromptTemplate(BaseModel):
    """One shipped or user-supplied synth prompt template.

    Instances are immutable (``frozen=True``) and reject unknown fields
    (``extra="forbid"``).

    Attributes:
        system_prompt: Non-empty system prompt text.
        user_template: Non-empty Jinja2 template for the user prompt. It must
            reference both the ``prose`` and ``n`` variables.
        output_parser: Name of the reply format the template asks for.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    system_prompt: str = Field(..., min_length=1)
    user_template: str = Field(..., min_length=1)
    output_parser: PromptParserKind = "json_list"

    @model_validator(mode="after")
    def _template_mentions_required_vars(self) -> SynthPromptTemplate:
        """Reject templates that are malformed or never use ``prose`` / ``n``.

        The template is parsed with Jinja2 instead of being searched for the
        literal text ``{{ prose }}``. Equivalent spellings such as ``{{prose}}``,
        ``{{- prose -}}`` or ``{{ prose | trim }}`` are therefore accepted, while
        a mention that only occurs inside a ``{# comment #}`` or ``{% raw %}``
        block is not counted as a reference.

        Raises:
            ValueError: If the template has a Jinja2 syntax error or does not
                reference every required variable.
        """
        try:
            parsed = Environment().parse(self.user_template)
        except TemplateSyntaxError as exc:
            # Report syntax problems at construction time as a validation error
            # rather than as a raw Jinja2 exception on the first render.
            raise ValueError(f"user_template is not a valid Jinja2 template: {exc}") from exc

        # Names the template reads from the render context (i.e. not set by itself).
        referenced = meta.find_undeclared_variables(parsed)
        missing = [name for name in ("prose", "n") if name not in referenced]
        if missing:
            raise ValueError(f"user_template must reference required variable(s) {missing!r}")
        return self

    def render_user_prompt(self, *, prose: str, n: int) -> str:
        """Render the user prompt with strict variable handling.

        Args:
            prose: Source text substituted for the ``prose`` variable.
            n: Value substituted for the ``n`` variable.

        Returns:
            The rendered template with leading and trailing whitespace removed.
        """
        # StrictUndefined makes any variable other than ``prose`` / ``n`` fail
        # loudly when used, instead of rendering as an empty string.
        template = Template(self.user_template, undefined=StrictUndefined)
        return template.render(prose=prose, n=n).strip()
38
39
# Shipped template for the "extraction" strategy: questions the prose itself answers.
_EXTRACTION_TEMPLATE = SynthPromptTemplate(
    output_parser="json_list",
    system_prompt=(
        "You generate high-quality single-turn instruction data from prose. Only write "
        "questions that the prose directly answers. Return a JSON list of objects with "
        "keys `question` and `answer`."
    ),
    user_template=(
        "Given the prose below, generate {{ n }} factual question/answer pairs whose "
        "answers are explicitly supported by the prose.\n\n"
        "Return only a JSON list.\n\n"
        "Prose:\n"
        "{{ prose }}"
    ),
)

# Shipped template for the "expansion" strategy: grounded follow-up questions.
_EXPANSION_TEMPLATE = SynthPromptTemplate(
    output_parser="json_list",
    system_prompt=(
        "You generate high-quality single-turn instruction data from prose. Write curious "
        "but grounded follow-up questions a reader might ask, and answer them as helpfully "
        "as possible. Return a JSON list of objects with keys `question` and `answer`."
    ),
    user_template=(
        "Given the prose below, generate {{ n }} question/answer pairs that expand on "
        "the material in a useful way without contradicting it.\n\n"
        "Return only a JSON list.\n\n"
        "Prose:\n"
        "{{ prose }}"
    ),
)

# Shipped templates, keyed by strategy name.
DEFAULT_PROMPT_TEMPLATES: Final[dict[Literal["extraction", "expansion"], SynthPromptTemplate]] = {
    "extraction": _EXTRACTION_TEMPLATE,
    "expansion": _EXPANSION_TEMPLATE,
}
71
72
def get_prompt_template(strategy: Literal["extraction", "expansion"]) -> SynthPromptTemplate:
    """Look up the shipped prompt template registered for ``strategy``.

    Raises:
        KeyError: If ``strategy`` is not a key of ``DEFAULT_PROMPT_TEMPLATES``.
    """
    shipped = DEFAULT_PROMPT_TEMPLATES[strategy]
    return shipped