Python · 2743 bytes Raw Blame History
1 """TRL-compatible `formatting_func` builder.
2
3 Branches per row shape:
4
5 - `"messages"` present → `tokenizer.apply_chat_template(msgs, tokenize=False)`.
6 SFTTrainer's completion-only loss masking kicks in automatically when
7 the formatted string is a chat transcript.
8 - `"text"` present → passthrough. SFTTrainer treats it as CPT (loss on
9 all tokens).
10 - neither → `DataFormatError`.
11
12 PREFERENCE rows (`prompt`/`chosen`/`rejected`) are NOT formatted here —
13 they're routed to DPOTrainer, which has its own formatter. This
14 function refuses them explicitly so an accidentally-mixed dataset
15 fails loudly at format time rather than producing silently-wrong data.
16 """
17
18 from __future__ import annotations
19
20 from collections.abc import Callable
21 from typing import TYPE_CHECKING, Any
22
23 from dlm.data.errors import DataFormatError
24
25 if TYPE_CHECKING:
26 from transformers import PreTrainedTokenizerBase
27
# One dataset row as handed to the formatting function: column name -> value.
Row = dict[str, Any]

# Shape of the callable returned by `make_formatting_func`: one row in, one
# rendered training string out.
FormattingFunc = Callable[
    [Row],
    str,
]
30
31
def make_formatting_func(tokenizer: PreTrainedTokenizerBase) -> FormattingFunc:
    """Return a row→str function bound to `tokenizer`'s chat template.

    The returned callable dispatches on which columns are present *and*
    non-None, in this order:

    - `messages` → rendered via `tokenizer.apply_chat_template(messages,
      tokenize=False, add_generation_prompt=False)`.
    - `text` → returned unchanged (must be a `str`).
    - `prompt` + `chosen` + `rejected` → refused; preference rows belong
      to DPOTrainer, not this formatter.

    Args:
        tokenizer: Tokenizer whose `apply_chat_template` renders
            conversational rows.

    Returns:
        A function mapping one row to its formatted training string.

    Raises (from the returned function):
        DataFormatError: `messages` is a bare string or an empty
            list/tuple; the chat template rendered a non-`str`; `text` is
            not a `str`; the row is a preference row; or the row has
            neither `messages` nor `text`.
    """

    def formatting_func(row: Row) -> str:
        # HF `datasets.Dataset` unifies the schema across mixed-shape
        # rows (e.g. PROSE + INSTRUCTION in one doc), filling missing
        # columns with `None`. Dispatch on presence-and-non-None so a
        # prose row with an injected `messages: None` doesn't route
        # into `apply_chat_template` and crash Jinja with a None
        # iterable.
        messages = row.get("messages")
        if messages is not None:
            # A bare string is not a list of message dicts, and an empty
            # conversation carries no training content. Reject both here
            # with a clear error instead of letting the chat template
            # fail cryptically or emit a content-free sample.
            if isinstance(messages, str):
                raise DataFormatError(
                    "`messages` field must be a list of message dicts, got str"
                )
            if isinstance(messages, (list, tuple)) and not messages:
                raise DataFormatError("`messages` field is an empty conversation")
            rendered = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            if not isinstance(rendered, str):
                raise DataFormatError(
                    f"apply_chat_template returned non-str ({type(rendered).__name__}); "
                    "ensure tokenize=False path is taken"
                )
            return rendered

        text = row.get("text")
        if text is not None:
            if not isinstance(text, str):
                raise DataFormatError(f"`text` field must be str, got {type(text).__name__}")
            return text

        if (
            row.get("prompt") is not None
            and row.get("chosen") is not None
            and row.get("rejected") is not None
        ):
            raise DataFormatError(
                "preference rows (prompt/chosen/rejected) must be routed to DPOTrainer, "
                "not SFTTrainer's formatting_func"
            )

        # Schema unification can pad `messages`/`text` with None, so the
        # full key list alone may misleadingly show them. Also report the
        # keys that actually carry data.
        non_null = sorted(key for key, value in row.items() if value is not None)
        raise DataFormatError(
            f"row has neither `messages` nor `text`: keys={sorted(row.keys())}, "
            f"non-null keys={non_null}"
        )

    return formatting_func