Python · 3788 bytes Raw Blame History
"""Parse `### Prompt` / `### Chosen` / `### Rejected` triples from a
`::preference::` section body.

Grammar (strict):

    ### Prompt
    <prompt body>
    ### Chosen
    <chosen body>
    ### Rejected
    <rejected body>
    (blank line)
    ### Prompt
    ...

The three headers must appear in order (Prompt → Chosen → Rejected) for
each triple. Missing, duplicated, or reordered headers raise
`PreferenceParseError`. Empty field bodies are errors — DPO on empty
text is never intentional.

Each field body ends at the first blank line or the next recognized
header, so a body cannot itself contain blank lines. Any number of blank
lines between triples is ignored.
"""
21
22 from __future__ import annotations
23
24 from dataclasses import dataclass
25
26 from dlm.data.errors import PreferenceParseError
27
28 _PROMPT = "### Prompt"
29 _CHOSEN = "### Chosen"
30 _REJECTED = "### Rejected"
31 _ALL_HEADERS = (_PROMPT, _CHOSEN, _REJECTED)
32
33
@dataclass(frozen=True, eq=True)
class PreferenceTriple:
    """One parsed preference record: a prompt plus a preferred and a rejected completion.

    Instances are immutable and compare/hash by value.
    """

    prompt: str  # text found under the ``### Prompt`` header
    chosen: str  # text found under the ``### Chosen`` header
    rejected: str  # text found under the ``### Rejected`` header
41
42
def parse_preference_body(body: str, *, section_id: str) -> list[PreferenceTriple]:
    """Parse every preference triple in a ``::preference::`` section body.

    Blank lines before, between, and after triples are skipped.

    Args:
        body: Raw section text.
        section_id: Identifier of the section, attached to any error raised.

    Returns:
        The triples in document order; the list is never empty.

    Raises:
        PreferenceParseError: if any triple is malformed, or if the body
            contains no triples at all.
    """
    cursor = _PeekableLines(body.splitlines())
    found: list[PreferenceTriple] = []

    cursor.skip_blank()
    while not cursor.eof():
        found.append(_parse_triple(cursor, section_id=section_id))
        cursor.skip_blank()

    if found:
        return found

    raise PreferenceParseError(
        "preference block has no ### Prompt / ### Chosen / ### Rejected triples",
        section_id=section_id,
        section_line=1,
    )
61
62
def _parse_triple(it: _PeekableLines, *, section_id: str) -> PreferenceTriple:
    """Consume one Prompt → Chosen → Rejected group from ``it``.

    Fields are parsed strictly in that order. The first malformed field
    raises ``PreferenceParseError`` (from ``_parse_field``), and nothing
    after it is read.
    """
    prompt, chosen, rejected = [
        _parse_field(it, expected=header, section_id=section_id)
        for header in (_PROMPT, _CHOSEN, _REJECTED)
    ]
    return PreferenceTriple(prompt=prompt, chosen=chosen, rejected=rejected)
68
69
def _parse_field(it: _PeekableLines, *, expected: str, section_id: str) -> str:
    """Consume the ``expected`` header line and return its non-empty body.

    Args:
        it: Cursor positioned on the line that should hold ``expected``.
        expected: Header text the current line must equal once stripped.
        section_id: Identifier of the enclosing section, used in errors.

    Returns:
        The field body, with surrounding whitespace stripped.

    Raises:
        PreferenceParseError: if the section ends before the header, the
            current line is not exactly ``expected``, or the body is empty.
            Every error points at the header's line.
    """
    # Capture the header's line before consuming anything. Once the body
    # has been read, the cursor sits on whatever follows it (the next
    # header, a line past a swallowed blank separator, or past the end),
    # which would misattribute an "empty body" error.
    header_line = it.line_no()
    line = it.peek_line()
    if line is None:
        raise PreferenceParseError(
            f"expected `{expected}` header, got end of section",
            section_id=section_id,
            section_line=header_line,
        )
    if line.strip() != expected:
        raise PreferenceParseError(
            f"expected `{expected}` header alone on its line, got {line!r}",
            section_id=section_id,
            section_line=header_line,
        )
    it.advance()

    body = _read_field_body(it)
    if not body:
        raise PreferenceParseError(
            f"`{expected}` body is empty",
            section_id=section_id,
            section_line=header_line,
        )
    return body
94
95
def _read_field_body(it: _PeekableLines) -> str:
    """Collect body lines up to a blank line, a recognized header, or the end.

    A terminating blank line is consumed. A terminating header is left
    under the cursor for the caller. The collected lines are joined with
    newlines and the result is stripped.
    """
    collected: list[str] = []
    while True:
        current = it.peek_line()
        if current is None:
            break
        stripped = current.strip()
        if not stripped:
            # The blank separator belongs to this field; swallow it.
            it.advance()
            break
        if stripped in _ALL_HEADERS:
            # Leave the header in place for the next field parse.
            break
        collected.append(current)
        it.advance()
    return "\n".join(collected).strip()
110
111
112 class _PeekableLines:
113 def __init__(self, lines: list[str]) -> None:
114 self._lines = lines
115 self._i = 0
116
117 def peek_line(self) -> str | None:
118 if self._i >= len(self._lines):
119 return None
120 return self._lines[self._i]
121
122 def advance(self) -> None:
123 self._i += 1
124
125 def eof(self) -> bool:
126 return self._i >= len(self._lines)
127
128 def line_no(self) -> int:
129 return self._i + 1
130
131 def skip_blank(self) -> None:
132 while not self.eof():
133 line = self.peek_line()
134 if line is None or line.strip() != "":
135 return
136 self.advance()