Python · 1510 bytes Raw Blame History
1 """Prompt template substrate for Sprint 43 synth generation."""
2
3 from __future__ import annotations
4
5 from typing import Literal, cast
6
7 import pytest
8
9 from dlm.synth import DEFAULT_PROMPT_TEMPLATES, SynthPromptTemplate, get_prompt_template
10
11
def test_shipped_prompt_templates_cover_both_strategies() -> None:
    """The shipped template registry is keyed by exactly the two synth strategies."""
    expected_strategies = frozenset({"extraction", "expansion"})
    shipped_strategies = frozenset(DEFAULT_PROMPT_TEMPLATES)
    assert shipped_strategies == expected_strategies
14
15
@pytest.mark.parametrize("strategy", ("extraction", "expansion"))
def test_get_prompt_template_returns_shipped_template(strategy: str) -> None:
    """Looking up a shipped strategy yields its registry entry (same object).

    Every shipped template is expected to declare the ``json_list`` output parser.
    """
    # pytest hands parameters over as plain ``str``; narrow for the type checker.
    name = cast(Literal["extraction", "expansion"], strategy)
    resolved = get_prompt_template(name)
    registry_entry = DEFAULT_PROMPT_TEMPLATES[name]

    assert resolved is registry_entry
    assert resolved.output_parser == "json_list"
22
23
def test_render_user_prompt_injects_required_values() -> None:
    """Rendering the extraction template substitutes both ``prose`` and ``n``.

    Distinctive sentinel values are used, and checked to be absent from the raw
    template first. Otherwise a bare ``"3"`` (or common words) could already
    appear in the shipped prompt text and satisfy the assertions even if
    substitution never happened.
    """
    template = get_prompt_template("extraction")
    prose = "alpha beta zeta-sentinel"
    n = 17

    # Guard: the sentinels must not be present before rendering, or the
    # post-render checks below would prove nothing.
    assert prose not in template.user_template
    assert str(n) not in template.user_template

    rendered = template.render_user_prompt(prose=prose, n=n)
    assert prose in rendered
    assert str(n) in rendered
29
30
@pytest.mark.parametrize(
    ("template", "missing"),
    [
        ("Missing one variable: {{ prose }}", "['n']"),
        ("Missing one variable: {{ n }}", "['prose']"),
        ("Missing both variables.", "['prose', 'n']"),
    ],
)
def test_user_template_must_reference_required_variables(
    template: str,
    missing: str,
) -> None:
    """A user template lacking ``{{ prose }}`` and/or ``{{ n }}`` is rejected.

    ``missing`` is the literal list repr expected in the error message. It is
    compared as a plain substring rather than passed to
    ``pytest.raises(match=...)``. ``match`` is a regex applied with
    ``re.search``, so ``"['n']"`` would be a character class matching any
    single ``'`` or ``n``, and the check would pass for almost any message.
    """
    with pytest.raises(ValueError) as excinfo:
        SynthPromptTemplate(
            system_prompt="hi",
            user_template=template,
        )
    assert missing in str(excinfo.value)
47 )