Python · 7231 bytes Raw Blame History
1 """Probe prompt derivation.
2
3 Two paths, picked in order:
4
5 1. **Explicit `!probe` markers.** Any `::instruction::` section whose
6 question header is `### Q !probe` is a user-declared probe. The
7 question body is the prompt; the answer body is the reference (not
8 enforced at eval time — generation is compared to the reference
9 only by human inspection in logs).
2. **Auto-sample from val split.** If fewer than `k` explicit probes
   exist, fill the remainder with questions from INSTRUCTION sections
   via a seed-stable sample. This guarantees every training run logs
   *something*; the user graduates to explicit `!probe` markers once
   they know which questions matter.
15
16 The emitted probes are just strings — `dlm.inference.generate` consumes
17 them with deterministic settings (temperature 0, do_sample=False) so
18 the output diff between runs is meaningful.
19 """
20
21 from __future__ import annotations
22
23 import hashlib
24 import logging
25 from dataclasses import dataclass
26
27 from dlm.data.errors import InstructionParseError
28 from dlm.data.instruction_parser import QAPair, parse_instruction_body
29 from dlm.doc.sections import Section, SectionType
30
31 _PROBE_MARKER = "!probe"
32 _PROBE_HEADER = f"### Q {_PROBE_MARKER}"
33 _LOG = logging.getLogger(__name__)
34
35
@dataclass(frozen=True)
class Probe:
    """A single probe: the prompt to generate from plus optional context.

    Instances are immutable (frozen dataclass), so they compare and hash
    by value.
    """

    # Text handed to the model as the generation prompt.
    prompt: str
    # Reference answer kept for log inspection; None when there is none.
    reference: str | None = None
    # Id of the section the probe was taken from; empty when unknown.
    section_id: str = ""
43
44
def extract_probes(sections: list[Section], *, k: int = 3, seed: int = 0) -> list[Probe]:
    """Return up to `k` probes derived from `sections`.

    Explicit `!probe`-marked questions take priority; if at least `k`
    explicit probes are found, the first `k` are returned and
    auto-sampling is skipped. Otherwise the remainder is filled from
    INSTRUCTION section Q/A pairs via a deterministic, seed-keyed sample
    that excludes prompts already chosen explicitly.

    Args:
        sections: Document sections to scan; only INSTRUCTION sections
            contribute probes.
        k: Maximum number of probes to return. Zero or a negative value
            yields an empty list.
        seed: Seed mixed into the auto-sample ordering.

    Returns:
        Explicit probes first (in document order), followed by any
        auto-sampled probes; never more than `k` entries.
    """
    # Parse once up front so a malformed section is warned about a single
    # time, regardless of which path consumes the pairs.
    parsed_pairs = _parse_instruction_sections(sections)

    # Without this guard a negative `k` would reach `explicit[:k]`, which
    # slices from the end and returns all but the last |k| explicit probes.
    if k <= 0:
        return []

    explicit = _extract_explicit_probes(sections, parsed_pairs=parsed_pairs)
    if len(explicit) >= k:
        return explicit[:k]

    auto = _auto_sample_probes(
        sections,
        k=k - len(explicit),
        seed=seed,
        exclude={probe.prompt for probe in explicit},
        parsed_pairs=parsed_pairs,
    )
    return [*explicit, *auto]
68
69
70 # --- internals ---------------------------------------------------------------
71
72
def _extract_explicit_probes(
    sections: list[Section],
    *,
    parsed_pairs: dict[str, list[QAPair]],
) -> list[Probe]:
    """Collect the user-declared probes from INSTRUCTION sections.

    A pair counts as an explicit probe when its question text begins
    with the `!probe:` sentinel. That sentinel and the surrounding
    whitespace are removed, so the marker never reaches the model as
    part of the prompt. The pair's answer is kept as the reference.

    Args:
        sections: Sections to scan, in document order.
        parsed_pairs: Q/A pairs keyed by section id; sections missing
            from the mapping contribute nothing.

    Returns:
        Probes in section order, then pair order within each section.
    """
    sentinel = f"{_PROBE_MARKER}:"
    cut = len(sentinel)
    probes: list[Probe] = []
    for section in sections:
        if section.type is SectionType.INSTRUCTION:
            probes.extend(
                Probe(
                    prompt=pair.question[cut:].strip(),
                    reference=pair.answer,
                    section_id=section.section_id,
                )
                for pair in parsed_pairs.get(section.section_id, [])
                if pair.question.startswith(sentinel)
            )
    return probes
103
104
def _normalize_probe_markers(body: str) -> str:
    """Move the `!probe` marker from a Q header onto the question body.

    Every line that is exactly `### Q !probe` (ignoring surrounding
    whitespace) is replaced by a bare `### Q` header. The next non-blank
    line is then emitted with a `!probe:` prefix so later stages can
    recognise the pair as an explicit probe.

    NOTE(review): blank lines between the header and that first
    non-blank line are dropped, so the output can have fewer lines than
    the input. If the next non-blank line is itself a `###` header, it
    is the line that receives the prefix. Confirm both are acceptable.

    Args:
        body: Raw section text.

    Returns:
        The rewritten text, with lines joined by a newline character.
    """
    remaining = iter(body.splitlines())
    result: list[str] = []
    for current in remaining:
        if current.strip() != _PROBE_HEADER:
            result.append(current)
            continue
        result.append("### Q")
        # Advance the shared iterator past blank lines; the first line with
        # content carries the sentinel. Running out of lines emits nothing.
        for follower in remaining:
            if follower.strip():
                result.append(f"{_PROBE_MARKER}:{follower}")
                break
    return "\n".join(result)
131
132
def _auto_sample_probes(
    sections: list[Section],
    *,
    k: int,
    seed: int,
    exclude: set[str],
    parsed_pairs: dict[str, list[QAPair]],
) -> list[Probe]:
    """Deterministically pick up to `k` distinct questions as probes.

    Candidates are the Q/A pairs of INSTRUCTION sections found in
    `parsed_pairs`. Each one is ordered by a SHA-256 digest of
    `(seed, question)` and the `k` smallest are kept, which gives a
    stable, seed-dependent selection without `random.Random`.

    Skipped candidates:
      * questions carrying the `!probe:` sentinel (explicit probes, which
        the caller collects separately);
      * prompts listed in `exclude`;
      * repeats of a question already collected. Identical questions in
        different sections would otherwise share a sort key and could
        fill several slots with the same prompt.

    Args:
        sections: Sections to scan, in document order.
        k: Number of probes wanted; zero or negative returns `[]`.
        seed: Seed mixed into the ordering digest.
        exclude: Prompts that must not be selected. Not mutated.
        parsed_pairs: Q/A pairs keyed by section id.

    Returns:
        At most `k` probes in digest order.
    """
    if k <= 0:
        return []

    sentinel = f"{_PROBE_MARKER}:"
    # Work on a copy so the caller's set is left untouched while still
    # tracking every prompt that has been claimed so far.
    taken = set(exclude)
    candidates: list[Probe] = []
    for section in sections:
        if section.type is not SectionType.INSTRUCTION:
            continue
        for pair in parsed_pairs.get(section.section_id, []):
            question = pair.question
            if question.startswith(sentinel) or question in taken:
                continue
            # First occurrence wins, matching the order a stable sort
            # would have given equal-key duplicates.
            taken.add(question)
            candidates.append(
                Probe(
                    prompt=question,
                    reference=pair.answer,
                    section_id=section.section_id,
                )
            )

    candidates.sort(key=lambda probe: _probe_sort_key(probe.prompt, seed))
    return candidates[:k]
183
184
def _probe_sort_key(prompt: str, seed: int) -> str:
    """Return the hex SHA-256 digest used to order `prompt` under `seed`.

    The seed and prompt are joined with a NUL byte before hashing.
    """
    payload = f"{seed}\x00{prompt}".encode()
    digest = hashlib.sha256(payload)
    return digest.hexdigest()
188
189
def _parse_instruction_sections(sections: list[Section]) -> dict[str, list[QAPair]]:
    """Parse every INSTRUCTION section once and index the pairs by id.

    Each section's content is passed through `_normalize_probe_markers`
    before parsing. A section that raises `InstructionParseError` is
    logged as a warning and mapped to an empty list, so callers can look
    it up without triggering a second parse (and a second warning).

    Args:
        sections: All document sections; non-INSTRUCTION ones are ignored.

    Returns:
        Mapping of section id to its parsed Q/A pairs.
    """
    by_id: dict[str, list[QAPair]] = {}
    for sec in sections:
        if sec.type is not SectionType.INSTRUCTION:
            continue
        try:
            pairs = parse_instruction_body(
                _normalize_probe_markers(sec.content),
                section_id=sec.section_id,
            )
        except InstructionParseError as exc:
            _LOG.warning(
                "probe extraction skipped malformed instruction section %s at line %d: %s",
                exc.section_id,
                exc.section_line,
                exc,
            )
            pairs = []
        by_id[sec.section_id] = pairs
    return by_id
209 return parsed