"""Parse `### Prompt` / `### Chosen` / `### Rejected` triples from a
`::preference::` section body.

Grammar (strict):

    ### Prompt
    <prompt body>
    ### Chosen
    <chosen body>
    ### Rejected
    <rejected body>
    (blank line)
    ### Prompt
    ...

The three headers must appear in order (Prompt → Chosen → Rejected) for
each triple. Each header must stand alone on its line; a field body ends
at the first blank line or at the next header. Missing, duplicated, or
reordered headers raise `PreferenceParseError`. Empty field bodies are
errors — DPO on empty text is never intentional.
"""
| 21 |
|
from __future__ import annotations

from dataclasses import dataclass

from dlm.data.errors import PreferenceParseError
| 27 |
|
# Header lines that open the three fields of a preference triple. The parser
# compares these against each input line after stripping surrounding
# whitespace, so a header is only recognized when it is alone on its line.
_PROMPT = "### Prompt"
_CHOSEN = "### Chosen"
_REJECTED = "### Rejected"
# Every recognized header; any of them terminates the field body being read.
_ALL_HEADERS = (_PROMPT, _CHOSEN, _REJECTED)
| 32 |
|
| 33 |
|
@dataclass(frozen=True)
class PreferenceTriple:
    """A single preference example: prompt + chosen + rejected completion.

    Instances are frozen, so fields cannot be reassigned after construction,
    and two triples with equal fields compare (and hash) equal.
    """

    # Text of the `### Prompt` field.
    prompt: str
    # Text of the `### Chosen` field (the preferred completion).
    chosen: str
    # Text of the `### Rejected` field (the dispreferred completion).
    rejected: str
| 41 |
|
| 42 |
|
def parse_preference_body(body: str, *, section_id: str) -> list[PreferenceTriple]:
    """Return the list of preference triples in `body`.

    Args:
        body: Raw text of a `::preference::` section.
        section_id: Identifier of the section; forwarded to every
            `PreferenceParseError` so the error can be attributed.

    Returns:
        The parsed triples in the order they appear. Never empty.

    Raises:
        PreferenceParseError: If `body` contains no triples at all (empty
            or blank-only), or if any triple is malformed (header missing
            or out of order, or an empty field body).
    """
    it = _PeekableLines(body.splitlines())
    # Leading blank lines before the first triple are tolerated.
    it.skip_blank()

    triples: list[PreferenceTriple] = []
    while not it.eof():
        triples.append(_parse_triple(it, section_id=section_id))
        # Any number of blank lines may separate consecutive triples.
        it.skip_blank()

    if not triples:
        # Only reachable when the body had no non-blank line: otherwise the
        # loop above either produced a triple or raised.
        raise PreferenceParseError(
            "preference block has no ### Prompt / ### Chosen / ### Rejected triples",
            section_id=section_id,
            section_line=1,
        )
    return triples
| 61 |
|
| 62 |
|
def _parse_triple(it: _PeekableLines, *, section_id: str) -> PreferenceTriple:
    """Parse one Prompt / Chosen / Rejected triple starting at the cursor.

    The three fields are read strictly in that order; `_parse_field` raises
    `PreferenceParseError` as soon as a header is missing or out of order,
    or a field body is empty.
    """
    # Separate statements (not inline keyword arguments) to make the required
    # consumption order of the shared cursor explicit.
    prompt = _parse_field(it, expected=_PROMPT, section_id=section_id)
    chosen = _parse_field(it, expected=_CHOSEN, section_id=section_id)
    rejected = _parse_field(it, expected=_REJECTED, section_id=section_id)
    return PreferenceTriple(prompt=prompt, chosen=chosen, rejected=rejected)
| 68 |
|
| 69 |
|
def _parse_field(it: _PeekableLines, *, expected: str, section_id: str) -> str:
    """Consume one `expected` header and its body; return the body text.

    Args:
        it: Line cursor positioned where the header should be.
        expected: The exact header string required at the cursor.
        section_id: Section identifier forwarded to raised errors.

    Returns:
        The non-empty field body, as produced by `_read_field_body`.

    Raises:
        PreferenceParseError: If the input ends before the header, if the
            current line (stripped) is not exactly `expected`, or if the
            field body is empty.
    """
    line = it.peek_line()
    if line is None:
        raise PreferenceParseError(
            f"expected `{expected}` header, got end of section",
            section_id=section_id,
            section_line=it.line_no(),
        )
    if line.strip() != expected:
        raise PreferenceParseError(
            f"expected `{expected}` header alone on its line, got {line!r}",
            section_id=section_id,
            section_line=it.line_no(),
        )
    it.advance()

    # Capture the line where the body had to start *before* reading it:
    # `_read_field_body` consumes a terminating blank line, so reading the
    # cursor position afterwards would report a different line depending on
    # whether the empty body was followed by a blank line or by a header.
    body_line = it.line_no()
    body = _read_field_body(it)
    if not body:
        raise PreferenceParseError(
            f"`{expected}` body is empty",
            section_id=section_id,
            section_line=body_line,
        )
    return body
| 94 |
|
| 95 |
|
def _read_field_body(it: _PeekableLines) -> str:
    """Read until a blank line or the next recognized header.

    A terminating blank line is consumed; a terminating header is left at
    the cursor so the caller can parse it as the next field. End of input
    also terminates the body.

    Returns:
        The collected lines joined with newlines, with surrounding
        whitespace stripped. Empty string when no body line was read.
    """
    buf: list[str] = []
    while True:
        line = it.peek_line()
        if line is None:
            # End of input. Checked explicitly rather than with `assert`,
            # which is removed when Python runs with -O.
            break
        stripped = line.strip()
        if not stripped:
            # Blank line ends the body and is swallowed.
            it.advance()
            break
        if stripped in _ALL_HEADERS:
            # Next field starts here; do not consume the header.
            break
        buf.append(line)
        it.advance()
    return "\n".join(buf).strip()
| 110 |
|
| 111 |
|
class _PeekableLines:
    """Forward-only cursor over a list of lines with one-line lookahead."""

    def __init__(self, lines: list[str]) -> None:
        self._lines = lines
        # Zero-based index of the current (not yet consumed) line.
        self._i = 0

    def peek_line(self) -> str | None:
        """Return the current line without consuming it, or None at the end."""
        return None if self.eof() else self._lines[self._i]

    def advance(self) -> None:
        """Move the cursor one line forward."""
        self._i += 1

    def eof(self) -> bool:
        """Return True once the cursor is past the last line."""
        return self._i >= len(self._lines)

    def line_no(self) -> int:
        """Return the 1-based number of the current line.

        At end of input this is one past the number of lines.
        """
        return self._i + 1

    def skip_blank(self) -> None:
        """Advance past consecutive empty or whitespace-only lines."""
        while not self.eof() and not self._lines[self._i].strip():
            self.advance()