1 """Cache key derivation for the tokenized-section cache.
2
3 A `CacheKey` identifies one tokenization of one section under one
4 tokenizer at one sequence length. All three inputs participate in the
5 filename AND the manifest key — a change to any invalidates the entry.
6
7 `tokenizer_sha256` computes the canonical fingerprint. The fast-
8 tokenizer path reads `tokenizer.json` bytes directly; the legacy
9 SentencePiece path falls back to a deterministic dump of
10 `__getstate__()`. Both paths pin the result on the tokenizer
11 instance so repeated calls in one run cost one hash.
12 """
13
14 from __future__ import annotations

import contextlib
import hashlib
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

_FINGERPRINT_ATTR = "_dlm_tokenizer_sha256"


@dataclass(frozen=True)
class CacheKey:
    """Composite key for one tokenized entry.

    `section_id` is 16-hex from `Section.section_id`. `tokenizer_sha`
    is the 64-hex sha256 from `tokenizer_sha256`. `sequence_len` is
    the tokenizer's max sequence length for this run.
    """

    section_id: str
    tokenizer_sha: str
    sequence_len: int

    def as_filename(self) -> str:
42 """Stable shard/name for on-disk storage.
43
44 Format: `<section_id>.<tok_sha[:12]>.seq<sequence_len>.npz`.
45 The 12-char tokenizer-sha prefix (48 bits) is plenty to avoid
46 collisions within a cache: caches are per-store, typical stores
47 see O(1) tokenizer families (one pinned fingerprint per base
48 model), so the collision space is "one entry per section per
49 tokenizer family" — astronomically far from the 2^24-entry
50 birthday threshold. The full sha is persisted in the manifest
51 for verification if a collision ever occurs in practice.
52 """
53 return f"{self.section_id}.{self.tokenizer_sha[:12]}.seq{self.sequence_len}.npz"

    def shard(self) -> str:
        """First 2 hex chars of section_id — the directory shard."""
        return self.section_id[:2]
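

# A minimal usage sketch (comments only, nothing here executes). The hex
# digests and sequence length below are hypothetical, not real cache entries:
#
#     key = CacheKey(
#         section_id="9f2c4e6a8b0d1357",  # 16-hex Section.section_id
#         tokenizer_sha="ab" * 32,        # 64-hex from tokenizer_sha256()
#         sequence_len=2048,
#     )
#     key.shard()        # -> "9f"
#     key.as_filename()  # -> "9f2c4e6a8b0d1357.abababababab.seq2048.npz"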


def tokenizer_sha256(tokenizer: Any) -> str:
61 """Canonical sha256 of a HuggingFace tokenizer's bytes.
62
63 For fast tokenizers (`tokenizer.backend_tokenizer` is a
64 `tokenizers.Tokenizer`), serialize via `to_str()` — the canonical
65 JSON form includes vocab, merges, normalizer, pre-tokenizer,
66 post-processor, and added-tokens. A bump in any of those shifts
67 the sha and invalidates the cache, which is exactly what we want.
68
69 For legacy SentencePiece-only tokenizers (no backend), fall back
70 to a deterministic `json.dumps(sorted dict)` of the vocab + special
71 tokens + model_max_length. This is weaker but deterministic
72 enough for our usage pattern (caches are per-store, not shared).
73
74 The result is pinned on the tokenizer instance via a private
75 attribute so repeated calls in one run are O(1).
76 """
    pinned: str | None = getattr(tokenizer, _FINGERPRINT_ATTR, None)
    if pinned is not None:
        return pinned

    backend = getattr(tokenizer, "backend_tokenizer", None)
    if backend is not None and hasattr(backend, "to_str"):
        try:
            canonical = backend.to_str()
        except Exception:  # noqa: BLE001 — defensive fallback
            canonical = _legacy_canonical(tokenizer)
    else:
        canonical = _legacy_canonical(tokenizer)

    sha = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    # Some tokenizer classes forbid new attributes; fine, just re-hash
    # next call.
    with contextlib.suppress(AttributeError, TypeError):
        object.__setattr__(tokenizer, _FINGERPRINT_ATTR, sha)
    return sha
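

# A minimal usage sketch (comments only, nothing here executes), assuming the
# `transformers` package is installed; the checkpoint name is illustrative:
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     sha = tokenizer_sha256(tok)           # 64-hex digest of the to_str() form
#     assert sha == tokenizer_sha256(tok)   # second call returns the pinned value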


def _legacy_canonical(tokenizer: PreTrainedTokenizerBase) -> str:
    """Legacy fallback: deterministic JSON dump of the tokenizer's
    identity-bearing state.

    Keys are sorted so the hash is stable across Python dict-ordering
    quirks. Non-serializable values are stringified.
    """
    state: dict[str, object] = {
        "vocab_size": getattr(tokenizer, "vocab_size", 0),
        "model_max_length": getattr(tokenizer, "model_max_length", 0),
        "pad_token": str(getattr(tokenizer, "pad_token", "")),
        "eos_token": str(getattr(tokenizer, "eos_token", "")),
        "bos_token": str(getattr(tokenizer, "bos_token", "")),
        "unk_token": str(getattr(tokenizer, "unk_token", "")),
        "cls_token": str(getattr(tokenizer, "cls_token", "")),
        "sep_token": str(getattr(tokenizer, "sep_token", "")),
        "mask_token": str(getattr(tokenizer, "mask_token", "")),
        "added_tokens_count": len(getattr(tokenizer, "added_tokens_decoder", {}) or {}),
        "class": tokenizer.__class__.__name__,
    }
    return json.dumps(state, sort_keys=True, default=str)
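

# How the pieces compose (comments only, nothing here executes). `cache_root`,
# `section`, and `tok` are assumed to come from the surrounding store code and
# are not defined in this module:
#
#     sha = tokenizer_sha256(tok)
#     key = CacheKey(
#         section_id=section.section_id,
#         tokenizer_sha=sha,
#         sequence_len=tok.model_max_length,
#     )
#     path = cache_root / key.shard() / key.as_filename()
#     # The manifest keeps the full 64-hex sha next to the entry, so a
#     # truncated-prefix collision in the filename can still be detected.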