Python · 4780 bytes Raw Blame History
1 """Pinned, hashed base-model downloader.
2
3 Every base we train on flows through `download_spec()`:
4
5 1. `huggingface_hub.snapshot_download` with `revision=spec.revision`
6 fetches the exact commit. HF's own layer verifies per-file against
7 its ETag/hash index.
8 2. We read the resolved snapshot's commit SHA back (HF stores it) and
9 refuse if it doesn't match what we asked for — branch races are
10 real on non-pinned tooling.
11 3. A deterministic `sha256` over `(relative_path, file_sha256)` pairs
12 produces a single digest for the manifest, so future runs can detect
13 a base tampered with on disk.
14 """
15
16 from __future__ import annotations
17
18 import hashlib
19 import logging
20 from dataclasses import dataclass
21 from pathlib import Path
22
23 from dlm.base_models.errors import GatedModelError
24 from dlm.base_models.schema import BaseModelSpec
25
26 _LOG = logging.getLogger(__name__)
27
28
@dataclass(frozen=True)
class DownloadResult:
    """Immutable record returned by a successful `download_spec()` call.

    Attributes:
        spec: The spec that was handed to `download_spec()`.
        path: Directory returned by `snapshot_download` for that spec.
        revision: Revision that passed the pin check against `spec.revision`.
        sha256: Hex digest produced by `sha256_of_directory(path)`.
    """

    spec: BaseModelSpec
    path: Path
    revision: str
    sha256: str
37
38
def download_spec(
    spec: BaseModelSpec,
    *,
    cache_dir: Path | None = None,
    local_dir: Path | None = None,
    local_files_only: bool = False,
) -> DownloadResult:
    """Download (or find locally) the snapshot for `spec`, pin-checked and hashed.

    Args:
        spec: Base model to fetch; `spec.hf_id` and `spec.revision` select
            the repository and the exact revision.
        cache_dir: Overrides `HF_HOME` when given.
        local_dir: When given, the snapshot is copied (non-symlinked) into
            this path; used by the per-`.dlm` store cache.
        local_files_only: When True, never touch the network (mirrors
            `HF_HUB_OFFLINE`).

    Returns:
        A `DownloadResult` carrying the snapshot path, the verified
        revision and the directory digest.

    Raises:
        GatedModelError: The repository is gated.
        RuntimeError: The snapshot is missing from the local cache, the
            repository does not exist, or the resolved revision differs
            from `spec.revision`.
    """
    # Function-local imports: huggingface_hub is only loaded once a
    # download is actually requested.
    from huggingface_hub import snapshot_download
    from huggingface_hub.errors import (
        GatedRepoError,
        LocalEntryNotFoundError,
        RepositoryNotFoundError,
    )

    repo_id = spec.hf_id
    wanted_revision = spec.revision
    cache_arg = str(cache_dir) if cache_dir else None
    local_arg = str(local_dir) if local_dir else None

    try:
        snapshot_location = snapshot_download(
            repo_id=repo_id,
            revision=wanted_revision,
            cache_dir=cache_arg,
            local_dir=local_arg,
            local_files_only=local_files_only,
        )
    except GatedRepoError as exc:
        raise GatedModelError(repo_id, spec.license_url) from exc
    except LocalEntryNotFoundError as exc:
        raise RuntimeError(
            f"{repo_id} not found in local cache and offline mode is on"
        ) from exc
    except RepositoryNotFoundError as exc:
        raise RuntimeError(f"HF repository not found: {repo_id}") from exc

    snapshot_path = Path(snapshot_location)

    # NOTE(review): for a `local_dir` copy the path is normally outside the
    # `snapshots/<sha>/` layout, so `_resolve_revision` hands back the
    # requested revision and this comparison cannot fail on that route.
    actual_revision = _resolve_revision(snapshot_path, wanted_revision)
    if actual_revision != wanted_revision:
        raise RuntimeError(
            f"revision mismatch for {repo_id}: asked {wanted_revision}, got {actual_revision}",
        )

    # NOTE(review): huggingface_hub may keep its own bookkeeping files under
    # `local_dir`; confirm they do not leak into the digest below.
    return DownloadResult(
        spec=spec,
        path=snapshot_path,
        revision=actual_revision,
        sha256=sha256_of_directory(snapshot_path),
    )
89
90
def sha256_of_directory(root: Path) -> str:
    """Return a deterministic content digest over every file under `root`.

    The digest is a SHA-256 over one record per regular file, in sorted
    order: `<posix-relative-path> NUL <hex sha256 of contents> LF`.
    Directories themselves contribute nothing; a symlink that resolves to
    a file is hashed through its target's contents.

    Args:
        root: Directory to hash.

    Returns:
        Lower-case hex digest string.

    Raises:
        NotADirectoryError: `root` is not an existing directory.
    """
    if not root.is_dir():
        raise NotADirectoryError(root)

    # Order on the case-sensitive relative path components rather than on
    # the Path objects: Path comparison is case-folded on Windows, which
    # would order e.g. `README.md` / `config.json` differently there and
    # change the digest for identical content. On POSIX this key yields
    # exactly the order Path comparison gives, so existing digests hold.
    entries = sorted(
        (
            (child.relative_to(root), child)
            for child in root.rglob("*")
            if child.is_file()
        ),
        key=lambda entry: entry[0].parts,
    )

    aggregator = hashlib.sha256()
    for relative, child in entries:
        aggregator.update(relative.as_posix().encode("utf-8"))
        aggregator.update(b"\0")
        aggregator.update(_sha256_of_file(child).encode("ascii"))
        aggregator.update(b"\n")
    return aggregator.hexdigest()
111
112
113 # --- internals ---------------------------------------------------------------
114
115
# Lengths of a full hex-encoded git object id: SHA-1 (40) and SHA-256 (64).
_COMMIT_SHA_LENGTHS = (40, 64)
_HEX_DIGITS = frozenset("0123456789abcdefABCDEF")


def _looks_like_commit_sha(name: str) -> bool:
    """Return True when `name` has the shape of a full hex git commit id."""
    return len(name) in _COMMIT_SHA_LENGTHS and _HEX_DIGITS.issuperset(name)


def _resolve_revision(path: Path, expected: str) -> str:
    """Return the commit SHA encoded in the snapshot path, else `expected`.

    `snapshot_download` uses the real SHA as the directory name under
    `snapshots/`, so resolving from the path is authoritative. When the
    path is not inside that layout (e.g. a `local_dir` copy) there is
    nothing to read back and `expected` is returned unchanged.

    Args:
        path: Directory returned by `snapshot_download`.
        expected: Revision that was requested.

    Returns:
        The name of the nearest `snapshots/<sha>` directory at or above
        `path`, or `expected` when no such directory exists.
    """
    # Walk from the path itself upwards so the nearest match wins.
    for ancestor in (path, *path.parents):
        # Require a commit-id-shaped name: a `local_dir` that merely lives
        # under some unrelated folder called `snapshots` must not be
        # mistaken for the cache layout and trigger a false mismatch.
        if ancestor.parent.name == "snapshots" and _looks_like_commit_sha(ancestor.name):
            return ancestor.name
    # Not inside the canonical layout. Trust the expected revision —
    # snapshot_download would have already refused if it didn't resolve.
    return expected
133
134
def _sha256_of_file(path: Path) -> str:
    """Return the hex SHA-256 of the file at `path`, read in 1 MiB chunks."""
    chunk_size = 1 << 20
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block_bytes = stream.read(chunk_size)
            if block_bytes == b"":
                break
            hasher.update(block_bytes)
    return hasher.hexdigest()