"""Pure dry-run orchestration for Sprint 43 synthetic instruction generation."""
2
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from typing import Literal

from dlm.doc.parser import ParsedDlm
from dlm.doc.sections import Section, SectionType
from dlm.synth.prompts import PromptParserKind, SynthStrategy, get_prompt_template
from dlm.synth.teachers import SynthTeacher
16
17 _NUMBERED_Q_RE = re.compile(r"^\s*\d+\.\s*(?:Q|Question):\s*(.+)\s*$", re.IGNORECASE)
18 _NUMBERED_A_RE = re.compile(r"^\s*(?:A|Answer):\s*(.+)\s*$", re.IGNORECASE)
19 _FENCED_JSON_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.DOTALL | re.IGNORECASE)
20
21
22 class SynthSkipReason(StrEnum):
23 """Why one prose section or generated pair did not produce an addition."""
24
25 NO_PROSE = "no_prose"
26 INVALID_OUTPUT = "invalid_output"
27 EMPTY_PAIR = "empty_pair"
28 ALREADY_PRESENT = "already_present"
29
30
@dataclass(frozen=True)
class SynthPair:
    """One generated instruction pair before materialization."""

    # Question text as parsed from the teacher's reply.
    question: str
    # Answer text paired with `question`.
    answer: str
37
38
@dataclass(frozen=True)
class SynthSourceSection:
    """One prose section selected for synthesis."""

    # Identifier of the originating document section.
    section_id: str
    # Prose text handed to the teacher prompt.
    content: str
45
46
# The strategies that map to an actual prompt template. A requested strategy
# outside these two is split into a mix of both by `_strategy_counts`.
ConcreteSynthStrategy = Literal["extraction", "expansion"]
48
49
@dataclass(frozen=True)
class PlannedSynthInstruction:
    """One auto-synth instruction section ready for review/apply."""

    # Prose section the pair was generated from.
    source: SynthSourceSection
    # Concrete strategy whose prompt template produced the pair.
    strategy: ConcreteSynthStrategy
    # The parsed question/answer pair.
    pair: SynthPair
    # The materialized INSTRUCTION section that would be added to the document.
    section: Section
58
59
@dataclass(frozen=True)
class SkippedSynthSection:
    """One source section or generated pair that the dry-run declined to add."""

    # Id of the source section, or "(document)" when the whole document was skipped.
    section_id: str
    # Strategy in play when the skip happened; None when no strategy applies.
    strategy: ConcreteSynthStrategy | None
    # Machine-readable cause of the skip.
    reason: SynthSkipReason
    # Optional human-readable explanation shown in the rendered plan.
    detail: str = ""
68
69
@dataclass(frozen=True)
class SynthRunPlan:
    """What the synth loop produced and what it declined to add."""

    # Instruction sections proposed for addition, in generation order.
    additions: tuple[PlannedSynthInstruction, ...]
    # Sections or pairs that were declined, each with a reason.
    skipped: tuple[SkippedSynthSection, ...]
76
77
def build_synth_plan(
    parsed: ParsedDlm,
    teacher: SynthTeacher,
    *,
    per_section: int = 3,
    strategy: SynthStrategy = "extraction",
    max_pairs: int | None = None,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float | None = None,
    seed: int | None = None,
    synth_at: str | None = None,
) -> SynthRunPlan:
    """Generate a dry-run synth plan from prose sections in `parsed`.

    For every non-empty PROSE section and every concrete strategy with a
    non-zero share of `per_section`, the teacher is called once and its reply
    is parsed into question/answer pairs. Each pair is materialized as an
    auto-synth INSTRUCTION section and recorded as an addition unless a section
    with the same id is already in the document or already planned.

    Args:
        parsed: Document whose `sections` are scanned for prose.
        teacher: Generator queried through `teacher.generate`; `teacher.name`
            is stamped on each materialized section.
        per_section: Pairs requested per prose section (split across
            strategies when `strategy` is not a single concrete one).
        strategy: Requested synthesis strategy.
        max_pairs: Optional cap on total additions; planning stops as soon as
            the cap is reached.
        max_new_tokens: Forwarded to the teacher.
        temperature: Forwarded to the teacher.
        top_p: Forwarded to the teacher.
        seed: Forwarded to the teacher.
        synth_at: Timestamp stamped on every materialized section; defaults to
            the current time when omitted.

    Returns:
        The planned additions and the skipped sections/pairs.

    Raises:
        ValueError: If `per_section`, `max_pairs` (when set) or
            `max_new_tokens` is less than 1.
    """
    if per_section < 1:
        raise ValueError(f"per_section must be >= 1, got {per_section}")
    if max_pairs is not None and max_pairs < 1:
        raise ValueError(f"max_pairs must be >= 1 when set, got {max_pairs}")
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")

    additions: list[PlannedSynthInstruction] = []
    skipped: list[SkippedSynthSection] = []
    # Seeded with the document's ids and extended with every planned addition,
    # so duplicates produced within this run are rejected as well.
    existing_ids = {section.section_id for section in parsed.sections}
    # One timestamp for the whole run keeps all additions consistently stamped.
    effective_synth_at = synth_at if synth_at is not None else _timestamp_now()

    sources = _extract_prose_sources(parsed.sections)
    if not sources:
        return SynthRunPlan(
            additions=(),
            skipped=(
                SkippedSynthSection(
                    section_id="(document)",
                    strategy=None,
                    reason=SynthSkipReason.NO_PROSE,
                    detail="document has no non-empty PROSE sections to synthesize from",
                ),
            ),
        )

    for source in sources:
        for concrete_strategy, count in _strategy_counts(strategy, per_section):
            # A mixed split can assign zero pairs to a strategy; skip the call.
            if count == 0:
                continue

            template = get_prompt_template(concrete_strategy)
            rendered = teacher.generate(
                template.system_prompt,
                template.render_user_prompt(prose=source.content, n=count),
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                seed=seed,
            )
            try:
                pairs = _parse_generated_pairs(rendered, parser=template.output_parser)
            except ValueError as exc:
                # An unparseable reply only loses this (section, strategy) call.
                skipped.append(
                    SkippedSynthSection(
                        section_id=source.section_id,
                        strategy=concrete_strategy,
                        reason=SynthSkipReason.INVALID_OUTPUT,
                        detail=str(exc),
                    )
                )
                continue

            # The teacher may return more pairs than requested; keep only `count`.
            for pair in pairs[:count]:
                materialized = Section(
                    type=SectionType.INSTRUCTION,
                    content=_render_instruction_body(pair),
                    auto_synth=True,
                    synth_teacher=teacher.name,
                    synth_strategy=concrete_strategy,
                    synth_at=effective_synth_at,
                    source_section_id=source.section_id,
                )
                if materialized.section_id in existing_ids:
                    skipped.append(
                        SkippedSynthSection(
                            section_id=source.section_id,
                            strategy=concrete_strategy,
                            reason=SynthSkipReason.ALREADY_PRESENT,
                            detail=f"section_id {materialized.section_id} already in document",
                        )
                    )
                    continue
                additions.append(
                    PlannedSynthInstruction(
                        source=source,
                        strategy=concrete_strategy,
                        pair=pair,
                        section=materialized,
                    )
                )
                existing_ids.add(materialized.section_id)
                # Stop immediately once the global cap is reached.
                if max_pairs is not None and len(additions) >= max_pairs:
                    return SynthRunPlan(additions=tuple(additions), skipped=tuple(skipped))

    return SynthRunPlan(additions=tuple(additions), skipped=tuple(skipped))
178
179
def render_synth_plan(plan: SynthRunPlan) -> str:
    """Plain-text rendering for dry-run CLI output and tests.

    The output starts with a summary line followed by a blank line. An
    "=== additions ===" block lists each planned section with the first line
    of its question and answer; an "=== skipped ===" block lists each skip
    with its strategy ("-" when none), reason value and detail. Either block is
    omitted when it would be empty.

    Args:
        plan: The plan to render.

    Returns:
        The rendered lines joined with newlines.
    """
    lines = [
        f"synth plan: {len(plan.additions)} add, {len(plan.skipped)} skip",
        "",
    ]
    if plan.additions:
        lines.append("=== additions ===")
        for add in plan.additions:
            lines.append("")
            lines.append(
                "+ ::instruction:: "
                f"[section_id={add.section.section_id} source={add.source.section_id} "
                f"teacher={add.section.synth_teacher} strategy={add.strategy}]"
            )
            # Only the first line of each side is shown to keep entries compact.
            lines.append("  q: " + _first_line(add.pair.question))
            lines.append("  a: " + _first_line(add.pair.answer))
    if plan.skipped:
        lines.append("")
        lines.append("=== skipped ===")
        for skip in plan.skipped:
            strategy = skip.strategy if skip.strategy is not None else "-"
            lines.append(f"- {skip.section_id} [{strategy}]: {skip.reason.value} ({skip.detail})")
    return "\n".join(lines)
204
205
206 def _extract_prose_sources(
207 sections: tuple[Section, ...] | list[Section],
208 ) -> list[SynthSourceSection]:
209 return [
210 SynthSourceSection(section_id=section.section_id, content=section.content.strip())
211 for section in sections
212 if section.type is SectionType.PROSE and section.content.strip()
213 ]
214
215
216 def _strategy_counts(
217 strategy: SynthStrategy, per_section: int
218 ) -> list[tuple[ConcreteSynthStrategy, int]]:
219 if strategy == "extraction":
220 return [("extraction", per_section)]
221 if strategy == "expansion":
222 return [("expansion", per_section)]
223 extraction = (per_section + 1) // 2
224 expansion = per_section // 2
225 return [("extraction", extraction), ("expansion", expansion)]
226
227
def _parse_generated_pairs(raw: str, *, parser: PromptParserKind) -> list[SynthPair]:
    """Parse a teacher reply with the parser named by the prompt template.

    Args:
        raw: The teacher's raw reply text.
        parser: `"json_list"` selects the JSON parser; any other value selects
            the numbered-list parser.

    Returns:
        A non-empty list of pairs.

    Raises:
        ValueError: If the reply is malformed or yields no pairs.
    """
    pairs = (
        _parse_json_list_pairs(raw) if parser == "json_list" else _parse_numbered_list_pairs(raw)
    )
    if not pairs:
        raise ValueError("teacher output produced no instruction pairs")
    return pairs
235
236
def _parse_json_list_pairs(raw: str) -> list[SynthPair]:
    """Parse a JSON list of `{"question": ..., "answer": ...}` objects.

    A reply wrapped in a single Markdown code fence is unwrapped first.

    Args:
        raw: The teacher's raw reply text.

    Returns:
        The parsed pairs with question and answer stripped; may be empty when
        the JSON list itself is empty.

    Raises:
        ValueError: If the reply is not valid JSON, is not a list, or any item
            is not an object with non-empty string `question` and `answer`.
    """
    candidate = _strip_json_fence(raw)
    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError as exc:
        raise ValueError(f"teacher output is not valid JSON: {exc}") from exc
    if not isinstance(payload, list):
        raise ValueError("teacher output must be a JSON list")

    pairs: list[SynthPair] = []
    for idx, item in enumerate(payload):
        if not isinstance(item, dict):
            raise ValueError(f"teacher output item {idx} must be an object")
        question = item.get("question")
        answer = item.get("answer")
        # Missing keys come back as None and fail the same check as non-strings.
        if not isinstance(question, str) or not isinstance(answer, str):
            raise ValueError(f"teacher output item {idx} must contain string question/answer keys")
        question_text = question.strip()
        answer_text = answer.strip()
        if not question_text or not answer_text:
            raise ValueError(f"teacher output item {idx} has an empty question or answer")
        pairs.append(SynthPair(question=question_text, answer=answer_text))
    return pairs
260
261
def _strip_json_fence(raw: str) -> str:
    """Return the payload of a fully fenced reply, or `raw` unchanged.

    Only a reply that is entirely one code fence (optionally tagged "json") is
    unwrapped; text with anything outside the fence is returned as-is.
    """
    match = _FENCED_JSON_RE.match(raw)
    if match is None:
        return raw
    return match.group(1).strip()
267
268
269 def _parse_numbered_list_pairs(raw: str) -> list[SynthPair]:
270 lines = [line for line in raw.splitlines() if line.strip()]
271 pairs: list[SynthPair] = []
272 idx = 0
273 while idx < len(lines):
274 q_match = _NUMBERED_Q_RE.match(lines[idx])
275 if q_match is None:
276 raise ValueError(
277 "teacher output numbered_list must use lines like `1. Q: ...` or `1. Question: ...`"
278 )
279 idx += 1
280 if idx >= len(lines):
281 raise ValueError("teacher output numbered_list is missing an answer line")
282 a_match = _NUMBERED_A_RE.match(lines[idx])
283 if a_match is None:
284 raise ValueError(
285 "teacher output numbered_list answers must use `A:` or `Answer:` lines"
286 )
287 idx += 1
288 question = q_match.group(1).strip()
289 answer = a_match.group(1).strip()
290 if not question or not answer:
291 raise ValueError("teacher output numbered_list contains an empty question or answer")
292 pairs.append(SynthPair(question=question, answer=answer))
293 return pairs
294
295
296 def _render_instruction_body(pair: SynthPair) -> str:
297 return "\n".join(
298 [
299 "### Q",
300 pair.question,
301 "### A",
302 pair.answer,
303 ]
304 )
305
306
307 def _timestamp_now() -> str:
308 return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
309
310
311 def _first_line(text: str, *, max_chars: int = 80) -> str:
312 first = text.strip().splitlines()[0] if text.strip() else ""
313 if len(first) > max_chars:
314 return first[: max_chars - 1] + "…"
315 return first