| 1 |
"""Prompt templates for synthetic instruction generation.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
from typing import Final, Literal |
| 6 |
|
| 7 |
from jinja2 import StrictUndefined, Template |
| 8 |
from pydantic import BaseModel, ConfigDict, Field, model_validator |
| 9 |
|
| 10 |
# Synthesis strategies a caller may request. Only "extraction" and "expansion"
# have shipped templates in DEFAULT_PROMPT_TEMPLATES; "both" has no template of
# its own (presumably it combines the two — confirm against callers).
SynthStrategy = Literal["extraction", "expansion", "both"]
# Parser kinds a SynthPromptTemplate may declare via its `output_parser` field.
# The parser implementations themselves are not visible in this module.
PromptParserKind = Literal["json_list", "numbered_list"]
| 12 |
|
| 13 |
|
| 14 |
class SynthPromptTemplate(BaseModel):
    """One shipped or user-supplied synth prompt template.

    Instances are immutable (``frozen=True``) and reject unknown fields
    (``extra="forbid"``).

    Attributes:
        system_prompt: Non-empty system prompt text.
        user_template: Non-empty Jinja2 source for the user prompt. It must
            reference both ``prose`` and ``n``; those are the only variables
            supplied by :meth:`render_user_prompt`.
        output_parser: Which parser kind the model's reply is declared to
            need; defaults to ``"json_list"``.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    system_prompt: str = Field(..., min_length=1)
    user_template: str = Field(..., min_length=1)
    output_parser: PromptParserKind = "json_list"

    @model_validator(mode="after")
    def _template_mentions_required_vars(self) -> SynthPromptTemplate:
        """Reject a ``user_template`` that is invalid Jinja2 or never uses ``prose``/``n``.

        Variables are detected from the parsed template, not by substring
        search. Therefore any valid spelling such as ``{{prose}}`` or
        ``{{ prose | trim }}`` is accepted. A mention that would never render
        (inside ``{% raw %}`` or a ``{# ... #}`` comment) does not count.

        Raises:
            ValueError: if the template fails to parse or a required variable
                is never referenced. Pydantic reports this as a
                ``ValidationError``.
        """
        # Local import: only this validator needs jinja2's parser/introspection API.
        from jinja2 import Environment, TemplateSyntaxError, meta

        try:
            parsed = Environment().parse(self.user_template)
        except TemplateSyntaxError as err:
            # Pydantic only wraps ValueError/AssertionError into ValidationError;
            # a raw TemplateSyntaxError would escape model validation unwrapped.
            raise ValueError(f"user_template is not valid Jinja2: {err}") from err

        referenced = meta.find_undeclared_variables(parsed)
        # Keep the ("prose", "n") order so the error message matches prior behavior.
        missing = [name for name in ("prose", "n") if name not in referenced]
        if missing:
            raise ValueError(f"user_template must reference required variable(s) {missing!r}")
        return self

    def render_user_prompt(self, *, prose: str, n: int) -> str:
        """Render ``user_template`` with ``prose`` and ``n``, stripping outer whitespace.

        The template is rendered with ``StrictUndefined``. Any variable it
        uses other than ``prose``/``n`` raises ``jinja2.UndefinedError``
        instead of silently rendering as an empty string.
        """
        template = Template(self.user_template, undefined=StrictUndefined)
        return template.render(prose=prose, n=n).strip()
| 38 |
|
| 39 |
|
| 40 |
# Shipped templates, keyed by strategy. Only the two single-purpose strategies
# have entries. Both ask the model for a JSON list of {question, answer}
# objects and declare the "json_list" output parser.
DEFAULT_PROMPT_TEMPLATES: Final[dict[Literal["extraction", "expansion"], SynthPromptTemplate]] = {
    # Extraction: questions whose answers are explicitly stated in the prose.
    "extraction": SynthPromptTemplate(
        output_parser="json_list",
        system_prompt=(
            "You generate high-quality single-turn instruction data from prose. "
            "Only write questions that the prose directly answers. "
            "Return a JSON list of objects with keys `question` and `answer`."
        ),
        user_template=(
            "Given the prose below, generate {{ n }} factual question/answer "
            "pairs whose answers are explicitly supported by the prose.\n\n"
            "Return only a JSON list.\n\n"
            "Prose:\n"
            "{{ prose }}"
        ),
    ),
    # Expansion: grounded follow-up questions that may go beyond the prose
    # without contradicting it.
    "expansion": SynthPromptTemplate(
        output_parser="json_list",
        system_prompt=(
            "You generate high-quality single-turn instruction data from prose. "
            "Write curious but grounded follow-up questions a reader might ask, "
            "and answer them as helpfully as possible. "
            "Return a JSON list of objects with keys `question` and `answer`."
        ),
        user_template=(
            "Given the prose below, generate {{ n }} question/answer pairs "
            "that expand on the material in a useful way without contradicting it.\n\n"
            "Return only a JSON list.\n\n"
            "Prose:\n"
            "{{ prose }}"
        ),
    ),
}
| 71 |
|
| 72 |
|
| 73 |
def get_prompt_template(strategy: Literal["extraction", "expansion"]) -> SynthPromptTemplate:
    """Return the shipped prompt template for one synth strategy.

    Args:
        strategy: ``"extraction"`` or ``"expansion"``. The composite
            ``"both"`` value of ``SynthStrategy`` has no template of its own.

    Raises:
        KeyError: if ``strategy`` has no shipped template. The message names
            the supported strategies, rather than just echoing the bad key.
    """
    try:
        return DEFAULT_PROMPT_TEMPLATES[strategy]
    except KeyError:
        supported = ", ".join(repr(name) for name in DEFAULT_PROMPT_TEMPLATES)
        # Same exception type as a plain dict lookup, so existing
        # `except KeyError` handlers keep working.
        raise KeyError(
            f"no shipped prompt template for strategy {strategy!r}; expected one of {supported}"
        ) from None