Python · 3823 bytes Raw Blame History
1 """Live drift checks for curated base-model registry entries."""
2
3 from __future__ import annotations
4
5 from collections.abc import Callable
6 from dataclasses import dataclass
7 from urllib.error import HTTPError, URLError
8 from urllib.request import Request, urlopen
9
10 from huggingface_hub import HfApi
11 from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError
12
13 from dlm.base_models import BASE_MODELS, BaseModelSpec
14
15 _USER_AGENT = "DocumentLanguageModel/registry-refresh"
16 FetchText = Callable[[str], str]
17
18
@dataclass(frozen=True)
class Drift:
    """Structured diff between a local registry entry and its live sources.

    `fields` holds one `(name, pinned, observed)` triple per drifted
    attribute: the attribute name, the value recorded locally, and the value
    seen live.
    """

    key: str
    hf_id: str
    fields: tuple[tuple[str, str, str], ...]

    def render(self) -> str:
        """Return a newline-joined, human-readable summary of this drift.

        The first line names the entry; each following line shows one drifted
        field as ``name  pinned -> observed``.
        """
        header = f" {self.key} ({self.hf_id})"
        # Separate the two reprs explicitly; back-to-back reprs such as
        # 'abc''def' are unreadable and ambiguous.
        rows = [
            f" {name:<22} {pinned!r} -> {observed!r}"
            for name, pinned, observed in self.fields
        ]
        return "\n".join([header, *rows])
32
33
def fetch_text(url: str) -> str:
    """Fetch `url` as text for provenance checks.

    The request carries the module's User-Agent and uses a 15 second timeout.
    The body is decoded with the charset declared in the response headers,
    defaulting to UTF-8, and undecodable bytes are replaced rather than
    raising.

    Errors from `urlopen` or from reading the body are not handled here and
    propagate to the caller.
    """

    req = Request(url, headers={"User-Agent": _USER_AGENT})
    with urlopen(req, timeout=15) as resp:
        body = bytes(resp.read())
        charset = str(resp.headers.get_content_charset() or "utf-8")

    try:
        return body.decode(charset, errors="replace")
    except LookupError:
        # The server declared a codec name Python does not know. The text is
        # only used for a substring check, so a lossy UTF-8 decode is better
        # than failing the whole fetch.
        return body.decode("utf-8", errors="replace")
42
43
def check_entry(
    api: HfApi,
    entry: BaseModelSpec,
    *,
    fetch_url_text: FetchText = fetch_text,
) -> Drift | None:
    """Return a structured drift report for one curated entry, if any.

    Checks performed, in order:

    1. Look up `entry.hf_id` through `api.model_info`. A gated or missing
       repository is reported immediately as a single-field drift.
    2. Compare the live commit sha with `entry.revision`.
    3. If `entry.refresh_check_hf_gating` is set, compare the live gating
       flag with `entry.requires_acceptance`.
    4. If the entry has both a provenance URL and a marker text, fetch the
       page with `fetch_url_text` and look for the marker, ignoring case.

    Returns a `Drift` listing every mismatch, or `None` when nothing drifted.
    """

    # Function-scope import: only the provenance handler below needs it.
    from http.client import HTTPException

    try:
        info = api.model_info(entry.hf_id)
    # NOTE(review): the gated handler is kept ahead of the not-found handler;
    # presumably GatedRepoError is the more specific of the two in
    # huggingface_hub - confirm against the installed version.
    except GatedRepoError:
        return Drift(
            key=entry.key,
            hf_id=entry.hf_id,
            fields=(("gating", "readable", "now fully gated"),),
        )
    except RepositoryNotFoundError:
        return Drift(
            key=entry.key,
            hf_id=entry.hf_id,
            fields=(("repository", "present", "missing (renamed or deleted)"),),
        )

    drifted: list[tuple[str, str, str]] = []

    # A falsy sha (None or empty) is treated as "unknown", not as drift.
    current_sha = info.sha
    if current_sha and current_sha != entry.revision:
        drifted.append(("revision", entry.revision, current_sha))

    if entry.refresh_check_hf_gating:
        gated = getattr(info, "gated", False)
        # Any truthy value other than the literal string "False" counts as
        # gated.
        gated_observed = bool(gated and gated != "False")
        if gated_observed != entry.requires_acceptance:
            drifted.append(
                (
                    "requires_acceptance",
                    str(entry.requires_acceptance),
                    str(gated_observed),
                ),
            )

    if entry.provenance_url and entry.provenance_match_text:
        expected = entry.provenance_match_text
        try:
            page = fetch_url_text(entry.provenance_url)
        except (
            HTTPError,
            URLError,
            TimeoutError,
            # Failures after the connection is open (peer reset, remote
            # disconnect, truncated body) are not wrapped in URLError. Treat
            # them as "unreachable" too, so one flaky page cannot abort the
            # whole registry check.
            ConnectionError,
            HTTPException,
            ValueError,
        ) as exc:
            drifted.append(
                (
                    "provenance_url",
                    f"{entry.provenance_url} contains {expected!r}",
                    f"unreachable ({type(exc).__name__})",
                )
            )
        else:
            if expected.casefold() not in page.casefold():
                drifted.append(
                    (
                        "provenance_marker",
                        expected,
                        f"missing from {entry.provenance_url}",
                    )
                )

    return Drift(key=entry.key, hf_id=entry.hf_id, fields=tuple(drifted)) if drifted else None
108
109
def check_registry(*, fetch_url_text: FetchText = fetch_text) -> list[Drift]:
    """Run `check_entry` over every curated entry and collect the drifts.

    One `HfApi` client is created and shared across all entries.
    `fetch_url_text` is passed through to each `check_entry` call. Entries
    with no drift are left out of the returned list.
    """

    client = HfApi()
    # Lazy generator: entries are still checked one at a time, in registry
    # order, as the list below consumes it.
    reports = (
        check_entry(client, spec, fetch_url_text=fetch_url_text)
        for spec in BASE_MODELS.values()
    )
    return [report for report in reports if report is not None]