"""Tokenizer load + fixup.

Three invariants enforced here (see CLAUDE.md pitfall #4):

1. **pad_token != eos_token.** HF defaults `pad_token` to `eos_token`
   on most bases; whenever `pad_token` is unset or equal to
   `eos_token`, we MUST pick a different token, or labels get
   corrupted by mid-sequence EOS masking.
   Fallback order: `unk_token` → else add `<|pad|>` as a new special
   token (which grows the vocab and sets `tokenizer_grew=True` for
   the caller to propagate into the LoRA config).
2. **chat_template must be present.** Without it, SFTTrainer can't
   render `messages`-shaped rows. We surface a typed
   `TokenizerBringupError` rather than letting SFT fail deep inside
   TRL with an opaque message.
3. **Revision pinning.** Every load goes through the base model's
   40-char revision SHA — never a branch — so retrains under the
   same spec reproduce.

Returns a `TokenizerBringup` dataclass rather than a bare tokenizer so
the `tokenizer_grew` flag travels with the object.
"""
22
23 from __future__ import annotations
24
25 from dataclasses import dataclass
26 from typing import TYPE_CHECKING, Any
27
28 from dlm.data.errors import TokenizerBringupError
29
30 if TYPE_CHECKING:
31 from transformers import PreTrainedTokenizerBase
32
# Text of the dedicated pad token that `_ensure_pad_token` registers as a
# last resort, when neither the existing pad token nor `unk_token` can be
# used as a pad distinct from `eos_token`.
_PAD_TOKEN_LITERAL = "<|pad|>"
34
35
@dataclass(frozen=True)
class TokenizerBringup:
    """Immutable bundle handed back by `prepare_tokenizer`.

    Keeps the fixed-up tokenizer together with the facts the caller
    needs about it, so the two cannot drift apart.

    When `tokenizer_grew` is True, a new `<|pad|>` token was added to
    the vocab. In that case the LoRA config MUST set
    `modules_to_save=["embed_tokens","lm_head"]`; otherwise the new
    embedding row will not be trained and its output distribution is
    undefined.
    """

    # The loaded tokenizer, after the pad-token and chat-template fixups.
    tokenizer: PreTrainedTokenizerBase
    # True when a new pad token had to be added to the vocab.
    tokenizer_grew: bool
    # The pad token in effect on `tokenizer`.
    pad_token: str
    # The chat template in effect on `tokenizer`.
    chat_template: str
51
52
def prepare_tokenizer(hf_id: str, revision: str) -> TokenizerBringup:
    """Load the tokenizer for `hf_id` at `revision`, apply pad/template fixups.

    The tokenizer is loaded with `use_fast=True`, then run through
    `_ensure_pad_token` and `_ensure_chat_template`. The resulting pad
    token and chat template are read back off the tokenizer and
    returned alongside it in a `TokenizerBringup`.

    Raises `TokenizerBringupError` if the chat template is missing or
    blank, or if the pad token / chat template is not a plain string
    once the fixups have run.
    """
    # Imported inside the function, so importing this module does not
    # itself import transformers.
    from transformers import AutoTokenizer

    tok: Any = AutoTokenizer.from_pretrained(hf_id, revision=revision, use_fast=True)
    grew = _ensure_pad_token(tok)
    _ensure_chat_template(tok, hf_id=hf_id)

    pad = tok.pad_token
    chat_template = tok.chat_template

    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, and a bare AssertionError would bypass the typed
    # error this module promises to callers.
    if not isinstance(pad, str):
        raise TokenizerBringupError(
            f"base model {hf_id!r}: pad_token is {type(pad).__name__} "
            "after fixup, expected str"
        )
    # `_ensure_chat_template` only rejects missing/blank values, so a
    # non-string template can still reach this point.
    # NOTE(review): presumably this covers tokenizers that expose a
    # mapping of named templates — confirm against transformers.
    if not isinstance(chat_template, str):
        raise TokenizerBringupError(
            f"base model {hf_id!r}: chat_template is "
            f"{type(chat_template).__name__}, expected str"
        )

    return TokenizerBringup(
        tokenizer=tok,
        tokenizer_grew=grew,
        pad_token=pad,
        chat_template=chat_template,
    )
72
73
def _ensure_pad_token(tok: Any) -> bool:
    """Guarantee `tok.pad_token` is set AND distinct from `tok.eos_token`.

    If the current pad token is already set and differs from eos, the
    tokenizer is left untouched. Otherwise the fallback order is:
    reuse `unk_token` when it exists and differs from eos, else
    register `_PAD_TOKEN_LITERAL` as the pad token through
    `add_special_tokens`.

    Returns True iff a new special token was added to the vocab.
    """
    eos = getattr(tok, "eos_token", None)
    current_pad = getattr(tok, "pad_token", None)

    if current_pad is not None and current_pad != eos:
        return False

    # Either pad is unset, or it equals eos (the HF default we must override).
    unk = getattr(tok, "unk_token", None)
    if unk is not None and unk != eos:
        tok.pad_token = unk
        return False

    # Last resort: register the dedicated pad literal. When it is new
    # to the vocab this grows it, which forces training to update
    # embed_tokens + lm_head.
    added = tok.add_special_tokens({"pad_token": _PAD_TOKEN_LITERAL})

    # `add_special_tokens` reports how many tokens were actually added.
    # An explicit zero means the literal was already in the vocab, so
    # nothing grew. Any other result (including a tokenizer that does
    # not report a count) is conservatively treated as growth.
    return not (isinstance(added, int) and added == 0)
95
96
def _ensure_chat_template(tok: Any, *, hf_id: str) -> None:
    """Fail fast when `tok` carries no usable chat template.

    A template counts as usable when the `chat_template` attribute
    exists, is not None, and its string form is not blank. Otherwise a
    `TokenizerBringupError` naming `hf_id` is raised.
    """
    candidate = getattr(tok, "chat_template", None)
    usable = candidate is not None and bool(str(candidate).strip())
    if usable:
        return

    message = (
        f"base model {hf_id!r} has no chat_template; "
        "supply one via --chat-template or pick a registry base"
    )
    raise TokenizerBringupError(message)