Python · 2688 bytes Raw Blame History
1 """Tiny-model fixture: SmolLM2-135M-Instruct, cached and reused across tests.
2
3 Design notes:
4
5 - Sprint 06 owns the base-model registry with pinned revisions. Until that
6 lands, we accept `main` here and log a warning on each resolution. The CI
7 cache key includes the stored revision, so cache hits are stable.
8 - Downloads use `snapshot_download` into the caller's `HF_HOME` (set by CI to
9 a cached location). Local runs default to `~/.cache/huggingface`.
10 - This module imports `huggingface_hub` lazily so test collection stays
11 fast for paths that don't need the tiny model.
12 """
13
14 from __future__ import annotations
15
16 import logging
17 import os
18 from collections.abc import Iterator
19 from pathlib import Path
20 from typing import Final
21
22 import pytest
23
# Module-level logger, named after this module.
_LOG: Final = logging.getLogger(__name__)

# Hugging Face Hub repo id of the tiny model downloaded by this module.
TINY_MODEL_HF_ID: Final = "HuggingFaceTB/SmolLM2-135M-Instruct"

# TODO(sprint-06): pin to a 40-char commit SHA via the base-model registry.
# Revision to download, overridable through DLM_TINY_MODEL_REVISION. An empty
# or whitespace-only value (e.g. an undefined CI variable expanded to "") is
# treated as unset and falls back to "main" instead of being passed through
# as an empty revision.
TINY_MODEL_REVISION: Final = (
    os.environ.get("DLM_TINY_MODEL_REVISION", "").strip() or "main"
)
30
31
def tiny_model_path() -> Path:
    """Resolve the tiny model's local snapshot directory, fetching if needed.

    Calls `snapshot_download` for `TINY_MODEL_HF_ID` at `TINY_MODEL_REVISION`
    and wraps the returned location in a `Path`. When `_offline_mode()` is
    true the call is made with `local_files_only=True`. A warning is logged
    on every call while the revision is still the unpinned `main`.

    Returns:
        The directory reported by `snapshot_download`, as a `Path`.

    Raises:
        Whatever `snapshot_download` raises (the download is not wrapped in
        any error handling here). Callers must mark their tests with
        `online` + `slow`.
    """
    # Deferred import keeps huggingface_hub out of test collection for the
    # fast-path tests that never need the tiny model.
    from huggingface_hub import snapshot_download

    revision_is_unpinned = TINY_MODEL_REVISION == "main"
    if revision_is_unpinned:
        _LOG.warning(
            "TINY_MODEL_REVISION unpinned (using 'main'). Sprint 06 will pin.",
        )

    use_cache_only = _offline_mode()
    snapshot_dir = snapshot_download(
        repo_id=TINY_MODEL_HF_ID,
        revision=TINY_MODEL_REVISION,
        local_files_only=use_cache_only,
    )
    return Path(snapshot_dir)
52
53
def _offline_mode() -> bool:
    """Report whether `HF_HUB_OFFLINE` currently requests offline operation.

    The variable is read from `os.environ` on every call, so changes made
    after import are picked up. The accepted truthy spellings mirror what
    huggingface_hub documents for its boolean environment variables
    ("1", "ON", "YES", "TRUE", case-insensitive); recognising only "1" would
    report online for e.g. `HF_HUB_OFFLINE=true` while the library itself
    treats that value as offline.

    Returns:
        True when the variable holds one of the truthy spellings; False when
        it is unset, empty, or any other value.
    """
    flag = os.environ.get("HF_HUB_OFFLINE", "")
    return flag.upper() in {"1", "ON", "YES", "TRUE"}
56
57
58 # --- pytest fixture wrapper ---------------------------------------------------
59
60
@pytest.fixture(scope="session")
def tiny_model_dir() -> Iterator[Path]:
    """Session-scoped fixture yielding the tiny model's directory.

    `tiny_model_path()` is invoked once per test session and its result is
    shared by every test that requests this fixture. Tests that use it must
    carry `@pytest.mark.online` and usually `@pytest.mark.slow`.

    Yields:
        The path returned by `tiny_model_path()`.
    """
    # Remove the offline switches before resolving the model so downloads
    # work for tests that opted in (the `online` marker is the gate).
    # NOTE(review): these are presumably set by an autouse fixture defined
    # elsewhere; confirm against conftest.
    offline_switches = (
        "HF_HUB_OFFLINE",
        "TRANSFORMERS_OFFLINE",
        "HF_DATASETS_OFFLINE",
    )
    saved = {name: os.environ.pop(name, None) for name in offline_switches}
    try:
        yield tiny_model_path()
    finally:
        # Runs when the fixture is finalized: put back every variable that
        # was set, and make sure the ones that were absent stay absent.
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
82 os.environ.pop(k, None)