"""Cache key derivation for the tokenized-section cache.

A `CacheKey` identifies one tokenization of one section under one
tokenizer at one sequence length. All three inputs participate in the
filename AND the manifest key — a change to any invalidates the entry.

`tokenizer_sha256` computes the canonical fingerprint. The fast-
tokenizer path hashes the backend tokenizer's canonical JSON
serialization (`to_str()`); the legacy SentencePiece path falls back
to a deterministic JSON dump of identity-bearing attributes. Both
paths pin the result on the tokenizer instance so repeated calls in
one run cost one hash.
"""
|
from __future__ import annotations

import contextlib
import hashlib
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

_FINGERPRINT_ATTR = "_dlm_tokenizer_sha256"

@dataclass(frozen=True)
class CacheKey:
    """Composite key addressing one tokenized cache entry.

    Fields:
        section_id: 16-hex identifier from `Section.section_id`.
        tokenizer_sha: 64-hex sha256 produced by `tokenizer_sha256`.
        sequence_len: the tokenizer's max sequence length for this run.
    """

    section_id: str
    tokenizer_sha: str
    sequence_len: int

    def as_filename(self) -> str:
        """Deterministic shard/name for on-disk storage.

        Shape: `<section_id>.<tok_sha[:12]>.seq<sequence_len>.npz`.
        Twelve hex chars of the tokenizer sha (48 bits) are ample:
        caches are per-store, and a typical store pins O(1) tokenizer
        families (one fingerprint per base model), leaving effectively
        one entry per section per tokenizer family — nowhere near the
        2^24-entry birthday threshold. The manifest persists the full
        sha for verification should a collision ever occur in practice.
        """
        sha_prefix = self.tokenizer_sha[:12]
        return f"{self.section_id}.{sha_prefix}.seq{self.sequence_len}.npz"

    def shard(self) -> str:
        """Directory shard: the leading two hex chars of section_id."""
        return self.section_id[:2]


def tokenizer_sha256(tokenizer: Any) -> str:
    """Canonical sha256 fingerprint of a HuggingFace tokenizer's bytes.

    Fast tokenizers expose `backend_tokenizer` (a `tokenizers.Tokenizer`);
    its `to_str()` is the canonical JSON serialization covering vocab,
    merges, normalizer, pre-tokenizer, post-processor, and added tokens.
    A bump in any of those shifts the sha and invalidates the cache,
    which is exactly what we want.

    Legacy SentencePiece-only tokenizers (no usable backend) fall back
    to `_legacy_canonical`, a deterministic JSON dump of vocab size,
    special tokens, and model_max_length. This is weaker but
    deterministic enough for our usage pattern (caches are per-store,
    not shared).

    The result is pinned on the tokenizer instance via a private
    attribute so repeated calls in one run are O(1).
    """
    cached: str | None = getattr(tokenizer, _FINGERPRINT_ATTR, None)
    if cached is not None:
        return cached

    backend = getattr(tokenizer, "backend_tokenizer", None)
    if backend is None or not hasattr(backend, "to_str"):
        canonical = _legacy_canonical(tokenizer)
    else:
        try:
            canonical = backend.to_str()
        except Exception:  # noqa: BLE001 — defensive fallback
            canonical = _legacy_canonical(tokenizer)

    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    # Some tokenizer classes forbid new attributes; fine, just re-hash
    # next call.
    with contextlib.suppress(AttributeError, TypeError):
        object.__setattr__(tokenizer, _FINGERPRINT_ATTR, digest)
    return digest


def _legacy_canonical(tokenizer: PreTrainedTokenizerBase) -> str: |
| 99 |
"""Legacy fallback: deterministic JSON dump of the tokenizer's |
| 100 |
identity-bearing state. |
| 101 |
|
| 102 |
Keys are sorted so the hash is stable across Python dict-ordering |
| 103 |
quirks. Non-serializable values are stringified. |
| 104 |
""" |
| 105 |
state: dict[str, object] = { |
| 106 |
"vocab_size": getattr(tokenizer, "vocab_size", 0), |
| 107 |
"model_max_length": getattr(tokenizer, "model_max_length", 0), |
| 108 |
"pad_token": str(getattr(tokenizer, "pad_token", "")), |
| 109 |
"eos_token": str(getattr(tokenizer, "eos_token", "")), |
| 110 |
"bos_token": str(getattr(tokenizer, "bos_token", "")), |
| 111 |
"unk_token": str(getattr(tokenizer, "unk_token", "")), |
| 112 |
"cls_token": str(getattr(tokenizer, "cls_token", "")), |
| 113 |
"sep_token": str(getattr(tokenizer, "sep_token", "")), |
| 114 |
"mask_token": str(getattr(tokenizer, "mask_token", "")), |
| 115 |
"added_tokens_count": len(getattr(tokenizer, "added_tokens_decoder", {}) or {}), |
| 116 |
"class": tokenizer.__class__.__name__, |
| 117 |
} |
| 118 |
return json.dumps(state, sort_keys=True, default=str) |