| 1 |
"""Probe prompt derivation. |
| 2 |
|
| 3 |
Two paths, picked in order: |
| 4 |
|
| 5 |
1. **Explicit `!probe` markers.** Any `::instruction::` section whose |
| 6 |
question header is `### Q !probe` is a user-declared probe. The |
| 7 |
question body is the prompt; the answer body is the reference (not |
| 8 |
enforced at eval time — generation is compared to the reference |
| 9 |
only by human inspection in logs). |
| 10 |
2. **Auto-sample from val split.** If fewer than `k` explicit probes
   exist, fill the remainder with up to `k` questions in total, drawn
   from INSTRUCTION sections via a seed-stable sample. This guarantees
   every training run logs *something*; the user graduates to explicit
   `!probe` markers once they know which questions matter.
| 15 |
|
| 16 |
The emitted probes are just strings — `dlm.inference.generate` consumes |
| 17 |
them with deterministic settings (temperature 0, do_sample=False) so |
| 18 |
the output diff between runs is meaningful. |
| 19 |
""" |
| 20 |
|
| 21 |
from __future__ import annotations |
| 22 |
|
| 23 |
import hashlib |
| 24 |
import logging |
| 25 |
from dataclasses import dataclass |
| 26 |
|
| 27 |
from dlm.data.errors import InstructionParseError |
| 28 |
from dlm.data.instruction_parser import QAPair, parse_instruction_body |
| 29 |
from dlm.doc.sections import Section, SectionType |
| 30 |
|
| 31 |
_PROBE_MARKER = "!probe" |
| 32 |
_PROBE_HEADER = f"### Q {_PROBE_MARKER}" |
| 33 |
_LOG = logging.getLogger(__name__) |
| 34 |
|
| 35 |
|
| 36 |
@dataclass(frozen=True)
class Probe:
    """Immutable record pairing a probe prompt with its reference answer.

    The reference is carried along for inspection only; nothing in this
    module compares generated output against it.
    """

    # Text handed to the model as the probe prompt.
    prompt: str
    # Answer taken from the originating Q/A pair, when one is available.
    reference: str | None = None
    # Identifier of the section the probe was derived from ("" if unset).
    section_id: str = ""
| 43 |
|
| 44 |
|
| 45 |
def extract_probes(sections: list[Section], *, k: int = 3, seed: int = 0) -> list[Probe]:
    """Return up to `k` probes derived from `sections`.

    Explicit `!probe`-marked questions take priority; if `k` explicit
    probes are found, auto-sampling is skipped. Otherwise the remainder
    is filled from INSTRUCTION section Q/A pairs via a deterministic
    sample that excludes prompts already chosen as explicit probes.

    Args:
        sections: Document sections to scan; only INSTRUCTION sections
            contribute probes.
        k: Maximum number of probes to return. A value of zero or less
            yields an empty list without parsing any section.
        seed: Seed mixed into the hash ordering used for auto-sampling.

    Returns:
        Explicit probes first (in document order), followed by any
        auto-sampled probes, at most `k` entries in total.
    """
    # A non-positive budget means "no probes". Without this guard a
    # negative `k` would reach `explicit[:k]` and silently return all
    # but the last |k| explicit probes.
    if k <= 0:
        return []

    # Parse once and share the result so a malformed section is only
    # reported a single time.
    parsed_pairs = _parse_instruction_sections(sections)
    explicit = _extract_explicit_probes(sections, parsed_pairs=parsed_pairs)
    if len(explicit) >= k:
        return explicit[:k]

    auto = _auto_sample_probes(
        sections,
        k=k - len(explicit),
        seed=seed,
        exclude={probe.prompt for probe in explicit},
        parsed_pairs=parsed_pairs,
    )
    return [*explicit, *auto]
| 68 |
|
| 69 |
|
| 70 |
# --- internals --------------------------------------------------------------- |
| 71 |
|
| 72 |
|
| 73 |
def _extract_explicit_probes(
    sections: list[Section],
    *,
    parsed_pairs: dict[str, list[QAPair]],
) -> list[Probe]:
    """Collect the user-declared probes from INSTRUCTION sections.

    A pair counts as an explicit probe when its question text begins
    with the `!probe:` sentinel planted by the marker normalizer. The
    sentinel and surrounding whitespace are removed so the marker never
    reaches model input; the pair's answer becomes the reference.

    Sections missing from `parsed_pairs` contribute nothing.
    """
    sentinel = f"{_PROBE_MARKER}:"
    skip = len(_PROBE_MARKER) + 1
    found: list[Probe] = []
    for sec in sections:
        if sec.type is SectionType.INSTRUCTION:
            for qa in parsed_pairs.get(sec.section_id, []):
                if not qa.question.startswith(sentinel):
                    continue
                stripped_prompt = qa.question[skip:].strip()
                probe = Probe(
                    prompt=stripped_prompt,
                    reference=qa.answer,
                    section_id=sec.section_id,
                )
                found.append(probe)
    return found
| 103 |
|
| 104 |
|
| 105 |
def _normalize_probe_markers(body: str) -> str: |
| 106 |
"""Rewrite `### Q !probe` headers so the instruction parser accepts them. |
| 107 |
|
| 108 |
The base parser treats anything after `### Q` as inline content and |
| 109 |
rejects it. We want `!probe` as a prefix marker rather than part of |
| 110 |
the grammar, so pre-process by stripping the marker off the header |
| 111 |
line and planting it on the first body line with a `!probe:` prefix. |
| 112 |
""" |
| 113 |
lines = body.splitlines() |
| 114 |
rewritten: list[str] = [] |
| 115 |
i = 0 |
| 116 |
while i < len(lines): |
| 117 |
line = lines[i] |
| 118 |
if line.strip() == _PROBE_HEADER: |
| 119 |
rewritten.append("### Q") |
| 120 |
# Find the first non-blank body line and prefix it. |
| 121 |
i += 1 |
| 122 |
while i < len(lines) and lines[i].strip() == "": |
| 123 |
i += 1 |
| 124 |
if i < len(lines): |
| 125 |
rewritten.append(f"{_PROBE_MARKER}:{lines[i]}") |
| 126 |
i += 1 |
| 127 |
continue |
| 128 |
rewritten.append(line) |
| 129 |
i += 1 |
| 130 |
return "\n".join(rewritten) |
| 131 |
|
| 132 |
|
| 133 |
def _auto_sample_probes( |
| 134 |
sections: list[Section], |
| 135 |
*, |
| 136 |
k: int, |
| 137 |
seed: int, |
| 138 |
exclude: set[str], |
| 139 |
parsed_pairs: dict[str, list[QAPair]], |
| 140 |
) -> list[Probe]: |
| 141 |
"""Deterministically pick `k` questions from INSTRUCTION sections. |
| 142 |
|
| 143 |
Hashes `(seed, question)` and keeps the top-k by hash — a stable |
| 144 |
weighted sample without needing `random.Random`. Excludes any |
| 145 |
prompt already in `exclude` (typically explicit probes). |
| 146 |
|
| 147 |
Parses the *normalized* section body so sections containing |
| 148 |
`### Q !probe` headers don't trip the strict instruction parser |
| 149 |
— we strip the marker, then filter out `!probe:`-prefixed bodies |
| 150 |
(those are the explicit probes, which the caller has already |
| 151 |
captured). |
| 152 |
""" |
| 153 |
if k <= 0: |
| 154 |
return [] |
| 155 |
|
| 156 |
candidates: list[Probe] = [] |
| 157 |
for section in sections: |
| 158 |
if section.type is not SectionType.INSTRUCTION: |
| 159 |
continue |
| 160 |
pairs = parsed_pairs.get(section.section_id, []) |
| 161 |
for pair in pairs: |
| 162 |
# Skip explicit probes (their question body was prefixed |
| 163 |
# with `!probe:` by the normalizer) — the caller handles |
| 164 |
# them separately. |
| 165 |
if pair.question.startswith(f"{_PROBE_MARKER}:"): |
| 166 |
continue |
| 167 |
if pair.question in exclude: |
| 168 |
continue |
| 169 |
candidates.append( |
| 170 |
Probe( |
| 171 |
prompt=pair.question, |
| 172 |
reference=pair.answer, |
| 173 |
section_id=section.section_id, |
| 174 |
) |
| 175 |
) |
| 176 |
|
| 177 |
if not candidates: |
| 178 |
return [] |
| 179 |
|
| 180 |
# Stable hash-based ordering. |
| 181 |
keyed = sorted(candidates, key=lambda p: _probe_sort_key(p.prompt, seed)) |
| 182 |
return keyed[:k] |
| 183 |
|
| 184 |
|
| 185 |
def _probe_sort_key(prompt: str, seed: int) -> str: |
| 186 |
h = hashlib.sha256(f"{seed}\x00{prompt}".encode()) |
| 187 |
return h.hexdigest() |
| 188 |
|
| 189 |
|
| 190 |
def _parse_instruction_sections(sections: list[Section]) -> dict[str, list[QAPair]]: |
| 191 |
"""Parse instruction sections once so malformed blocks warn once.""" |
| 192 |
parsed: dict[str, list[QAPair]] = {} |
| 193 |
for section in sections: |
| 194 |
if section.type is not SectionType.INSTRUCTION: |
| 195 |
continue |
| 196 |
try: |
| 197 |
parsed[section.section_id] = parse_instruction_body( |
| 198 |
_normalize_probe_markers(section.content), |
| 199 |
section_id=section.section_id, |
| 200 |
) |
| 201 |
except InstructionParseError as exc: |
| 202 |
_LOG.warning( |
| 203 |
"probe extraction skipped malformed instruction section %s at line %d: %s", |
| 204 |
exc.section_id, |
| 205 |
exc.section_line, |
| 206 |
exc, |
| 207 |
) |
| 208 |
parsed[section.section_id] = [] |
| 209 |
return parsed |