Python · 2377 bytes Raw Blame History
1 """Canonical tokenizer-vocabulary-extension contract.
2
3 A training run whose bringup adds a new special token grows the
4 vocabulary. Every downstream stage — LoRA config
5 (`modules_to_save`), export preflight
6 (`tokenizer_from_adapter_dir.vocab_size == gguf_base.vocab_size +
7 N_added`), Modelfile stops — depends on *the same* predicate for
8 "did this tokenizer grow". This module is that predicate's canonical
9 home.
10
11 Two functions:
12
13 - `tokenizer_grew(base, final)` — True iff `vocab_size` changed or the
14 added-token set changed. Works for any `PreTrainedTokenizerBase`
15 (BPE or SentencePiece family).
16 - `modules_to_save_for_growth(grew)` — `["embed_tokens", "lm_head"]`
17 when `grew=True`, else `[]`. Training calls this when building the
18 LoRA config. Per pitfall #4, without the modules_to_save entry the
19 new embedding row's output is undefined.
20
21 The predicate lives in `dlm.data` because the tokenizer itself is a
22 data-stage concern; the downstream callers (`dlm.train.adapter`,
23 `dlm.export.preflight`) import from here rather than reimplementing.
24 """
25
26 from __future__ import annotations
27
28 from typing import TYPE_CHECKING
29
30 if TYPE_CHECKING:
31 from transformers import PreTrainedTokenizerBase
32
33
def tokenizer_grew(base: PreTrainedTokenizerBase, final: PreTrainedTokenizerBase) -> bool:
    """True iff `final` has a larger vocab or a different added-token mapping than `base`.

    `vocab_size` comparison catches base-vocabulary growth. NOTE(review):
    on Hugging Face tokenizers `vocab_size` typically *excludes* tokens
    registered via `add_special_tokens`/`add_tokens` (those show up in
    `len(tokenizer)` and `get_added_vocab()` instead) — so the
    `add_special_tokens` pad-fallback path is in practice caught by the
    added-vocab comparison below, not by the size check.

    The `get_added_vocab()` comparison uses the full token -> id dict,
    not just the token strings: a `set(dict)` comparison would iterate
    keys only and miss a token whose id was reassigned (same text,
    different embedding row) — exactly the "contents differ but count is
    unchanged" case this predicate promises to catch.
    """
    if final.vocab_size != base.vocab_size:
        return True
    # Dict equality checks both token text and token id, so replaced
    # tokens AND remapped ids register as growth.
    return base.get_added_vocab() != final.get_added_vocab()
48
49
def modules_to_save_for_growth(grew: bool) -> list[str]:
    """Return the LoRA `modules_to_save` list for a grown tokenizer.

    When the tokenizer grew, the input embedding matrix and the output
    projection must be trained alongside the adapter ranks so the new
    token rows produce meaningful outputs — hence
    `["embed_tokens", "lm_head"]`. An unchanged tokenizer gets `[]`,
    which keeps the adapter small.
    """
    if not grew:
        return []
    return ["embed_tokens", "lm_head"]