"""Pinned, hashed base-model downloader.

Every base we train on flows through `download_spec()`:

1. `huggingface_hub.snapshot_download` with `revision=spec.revision`
   fetches the exact commit. HF's own layer verifies per-file against
   its ETag/hash index.
2. We read the resolved snapshot's commit SHA back (HF stores it) and
   refuse if it doesn't match what we asked for — branch races are
   real on non-pinned tooling.
3. A deterministic `sha256` over `(relative_path, file_sha256)` pairs
   produces a single digest for the manifest, so future runs can detect
   a base tampered with on disk.
"""

from __future__ import annotations

import hashlib
import logging
from dataclasses import dataclass
from pathlib import Path

from dlm.base_models.errors import GatedModelError
from dlm.base_models.schema import BaseModelSpec
|
# Module-level logger named after this module, per stdlib convention.
# NOTE(review): unused in the visible code — presumably used by callers
# or future logging; confirm before removing.
_LOG = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
@dataclass(frozen=True)
class DownloadResult:
    """Outcome of a successful `download_spec()` call."""

    # The spec that was requested (carries hf_id, revision, license_url).
    spec: BaseModelSpec
    # Local filesystem path of the resolved snapshot directory.
    path: Path
    # Commit SHA the snapshot actually resolved to (verified against
    # `spec.revision` by `download_spec`).
    revision: str
    # Aggregate digest from `sha256_of_directory(path)`, recorded in the
    # manifest for later tamper detection.
    sha256: str
|
| 38 |
|
| 39 |
def download_spec(
    spec: BaseModelSpec,
    *,
    cache_dir: Path | None = None,
    local_dir: Path | None = None,
    local_files_only: bool = False,
) -> DownloadResult:
    """Fetch (or locate) the snapshot for `spec` and return a pinned reference.

    `cache_dir` overrides `HF_HOME`. `local_dir` copies the snapshot into
    a specific path (non-symlinked) — used by the per-`.dlm` store cache.
    `local_files_only=True` refuses to hit the network (mirrors
    `HF_HUB_OFFLINE`).

    Raises:
        GatedModelError: the repo is gated and access was denied.
        RuntimeError: offline cache miss, unknown repo, or a resolved
            revision that differs from the pinned one.
    """
    # Imported lazily so merely importing this module stays cheap.
    from huggingface_hub import snapshot_download
    from huggingface_hub.errors import (
        GatedRepoError,
        LocalEntryNotFoundError,
        RepositoryNotFoundError,
    )

    # GatedRepoError subclasses RepositoryNotFoundError, so it must be
    # caught first to surface the license URL to the user.
    try:
        snapshot_root = snapshot_download(
            repo_id=spec.hf_id,
            revision=spec.revision,
            cache_dir=None if cache_dir is None else str(cache_dir),
            local_dir=None if local_dir is None else str(local_dir),
            local_files_only=local_files_only,
        )
    except GatedRepoError as exc:
        raise GatedModelError(spec.hf_id, spec.license_url) from exc
    except LocalEntryNotFoundError as exc:
        raise RuntimeError(f"{spec.hf_id} not found in local cache and offline mode is on") from exc
    except RepositoryNotFoundError as exc:
        raise RuntimeError(f"HF repository not found: {spec.hf_id}") from exc

    snapshot_dir = Path(snapshot_root)
    actual = _resolve_revision(snapshot_dir, spec.revision)
    # Guard against branch races: the snapshot we got must be the exact
    # commit the spec pinned.
    if actual != spec.revision:
        raise RuntimeError(
            f"revision mismatch for {spec.hf_id}: asked {spec.revision}, got {actual}",
        )

    return DownloadResult(
        spec=spec,
        path=snapshot_dir,
        revision=actual,
        sha256=sha256_of_directory(snapshot_dir),
    )
| 89 |
|
| 90 |
|
| 91 |
def sha256_of_directory(root: Path) -> str:
    """Deterministic content digest over every file under `root`.

    Input: `(posix-relative-path, sha256-of-contents)` pairs sorted by
    the posix relative-path string. Output: hex digest.

    The sort key is exactly the string that gets hashed. Sorting raw
    `Path` objects (the previous behavior) orders by platform separator
    and by an interpreter-version-dependent comparison (parts tuple vs
    normalized string), so the same tree could hash differently on
    Windows vs posix or across Python versions — breaking the "stable
    across filesystems" contract. NOTE(review): this changes the digest
    for trees where `Path` order and posix-string order disagreed;
    manifests recorded before this fix may need re-hashing.

    Raises:
        NotADirectoryError: if `root` is not an existing directory.
    """
    if not root.is_dir():
        raise NotADirectoryError(root)

    # Pair each file with its posix relative path first, then sort on
    # that string so the ordering is OS- and version-independent.
    entries = sorted(
        (child.relative_to(root).as_posix(), child)
        for child in root.rglob("*")
        if child.is_file()
    )

    aggregator = hashlib.sha256()
    for rel, child in entries:
        aggregator.update(rel.encode("utf-8"))
        aggregator.update(b"\0")  # separator: path names cannot contain NUL
        aggregator.update(_sha256_of_file(child).encode("ascii"))
        aggregator.update(b"\n")
    return aggregator.hexdigest()
| 111 |
|
| 112 |
|
| 113 |
# --- internals --------------------------------------------------------------- |
| 114 |
|
| 115 |
|
| 116 |
def _resolve_revision(path: Path, expected: str) -> str: |
| 117 |
"""Return the commit SHA HF wrote into the snapshot, falling back |
| 118 |
to the expected value when HF didn't emit a `.gitcommit` marker. |
| 119 |
|
| 120 |
`snapshot_download` uses the real SHA as the directory name under |
| 121 |
`snapshots/`, so resolving from the path is authoritative; we only |
| 122 |
fall back when `local_dir` copies files out of that structure. |
| 123 |
""" |
| 124 |
# HF's canonical snapshot dir is `.../snapshots/<sha>/`. Walking up |
| 125 |
# until we find that name is the cheapest check. |
| 126 |
for ancestor in (path, *path.parents): |
| 127 |
if ancestor.parent.name == "snapshots": |
| 128 |
return ancestor.name |
| 129 |
# Not inside the canonical layout (e.g., `local_dir` copy). Trust |
| 130 |
# the expected SHA — snapshot_download would have already refused |
| 131 |
# if the revision didn't resolve. |
| 132 |
return expected |
| 133 |
|
| 134 |
|
| 135 |
def _sha256_of_file(path: Path) -> str: |
| 136 |
digest = hashlib.sha256() |
| 137 |
with path.open("rb") as fh: |
| 138 |
for chunk in iter(lambda: fh.read(1 << 20), b""): |
| 139 |
digest.update(chunk) |
| 140 |
return digest.hexdigest() |