1 """Preflight checks — adapter config, tokenizer vocab, chat template."""
2
3 from __future__ import annotations
4
5 import json
6 from pathlib import Path
7 from types import SimpleNamespace
8
9 import pytest
10
11 from dlm.base_models import BASE_MODELS
12 from dlm.export.errors import PreflightError
13 from dlm.export.preflight import (
14 check_adapter_config,
15 check_chat_template,
16 check_pretokenizer_fingerprint,
17 check_tokenizer_vocab,
18 check_vl_target_modules_lm_only,
19 check_was_adapter_qlora,
20 )
21
22 _SPEC = BASE_MODELS["smollm2-135m"]
23
24
def _write_adapter_config(dir_: Path, **overrides: object) -> None:
    """Plant a minimal LoRA ``adapter_config.json`` under *dir_*.

    Defaults point at the expected base model; keyword *overrides* win.
    """
    payload: dict[str, object] = {
        "base_model_name_or_path": _SPEC.hf_id,
        "peft_type": "LORA",
        **overrides,
    }
    dir_.mkdir(parents=True, exist_ok=True)
    (dir_ / "adapter_config.json").write_text(json.dumps(payload))
30
31
def _write_tokenizer_config(dir_: Path, **overrides: object) -> None:
    """Plant a minimal ``tokenizer_config.json`` under *dir_*.

    Supplies a vocab size and a trivial chat template by default; keyword
    *overrides* replace (or add to) those defaults.
    """
    payload: dict[str, object] = {
        "vocab_size": 32000,
        "chat_template": "{{messages}}",
        **overrides,
    }
    dir_.mkdir(parents=True, exist_ok=True)
    (dir_ / "tokenizer_config.json").write_text(json.dumps(payload))
37
38
def _write_pinned_versions(dir_: Path, *, bnb: str | None) -> None:
    """Plant ``pinned_versions.json`` under *dir_* with the given bitsandbytes pin.

    ``bnb=None`` records that bitsandbytes was not installed at training time.
    """
    payload = {"torch": "2.4.0", "bitsandbytes": bnb}
    dir_.mkdir(parents=True, exist_ok=True)
    (dir_ / "pinned_versions.json").write_text(json.dumps(payload))
43
44
class TestAdapterConfig:
    """`check_adapter_config`: base-model match, missing file, bad JSON."""

    def test_matching_base_ok(self, tmp_path: Path) -> None:
        # A config whose base matches the spec passes silently.
        _write_adapter_config(tmp_path)
        check_adapter_config(tmp_path, _SPEC)

    def test_missing_config_raises(self, tmp_path: Path) -> None:
        # No adapter_config.json at all is a preflight failure.
        with pytest.raises(PreflightError, match="adapter_config"):
            check_adapter_config(tmp_path, _SPEC)

    def test_mismatched_base_raises(self, tmp_path: Path) -> None:
        # Adapter trained against a different base must be rejected.
        _write_adapter_config(tmp_path, base_model_name_or_path="other/base")
        with pytest.raises(PreflightError, match="was trained against"):
            check_adapter_config(tmp_path, _SPEC)

    def test_malformed_json_raises(self, tmp_path: Path) -> None:
        config_path = tmp_path / "adapter_config.json"
        config_path.write_text("not json {{{")
        with pytest.raises(PreflightError, match="cannot parse"):
            check_adapter_config(tmp_path, _SPEC)
63
64
class TestTokenizerVocab:
    """`check_tokenizer_vocab`: config lookup, tokenizer.json fallback, errors."""

    def test_vocab_size_from_config(self, tmp_path: Path) -> None:
        _write_tokenizer_config(tmp_path, vocab_size=50257)
        assert check_tokenizer_vocab(tmp_path) == 50257

    def test_fallback_to_tokenizer_json(self, tmp_path: Path) -> None:
        _write_tokenizer_config(tmp_path)
        # Strip vocab_size out of the config, then plant a tokenizer.json
        # carrying model.vocab so the fallback path is exercised.
        config_path = tmp_path / "tokenizer_config.json"
        cfg = json.loads(config_path.read_text())
        del cfg["vocab_size"]
        config_path.write_text(json.dumps(cfg))
        vocab = {str(i): i for i in range(5000)}
        (tmp_path / "tokenizer.json").write_text(json.dumps({"model": {"vocab": vocab}}))
        assert check_tokenizer_vocab(tmp_path) == 5000

    def test_missing_tokenizer_config_raises(self, tmp_path: Path) -> None:
        # The error message must not leak internal sprint references.
        with pytest.raises(PreflightError, match="tokenizer metadata capture") as exc_info:
            check_tokenizer_vocab(tmp_path)
        assert "Sprint" not in str(exc_info.value)

    def test_malformed_config_raises(self, tmp_path: Path) -> None:
        (tmp_path / "tokenizer_config.json").write_text("not json {{{")
        with pytest.raises(PreflightError, match="cannot parse"):
            check_tokenizer_vocab(tmp_path)

    def test_no_vocab_info_anywhere_raises(self, tmp_path: Path) -> None:
        # Empty config and no tokenizer.json: vocab size is undeterminable.
        (tmp_path / "tokenizer_config.json").write_text(json.dumps({}))
        with pytest.raises(PreflightError, match="cannot determine vocab"):
            check_tokenizer_vocab(tmp_path)

    def test_malformed_tokenizer_json_raises(self, tmp_path: Path) -> None:
        # Config carries no vocab_size, so the fallback file is consulted
        # and its corruption must surface as a parse error.
        config_payload = json.dumps({"chat_template": "{{messages}}"})
        (tmp_path / "tokenizer_config.json").write_text(config_payload)
        (tmp_path / "tokenizer.json").write_text("not json {{{")
        with pytest.raises(PreflightError, match="cannot parse"):
            check_tokenizer_vocab(tmp_path)
103
104
class TestChatTemplate:
    """`check_chat_template`: presence, blank/whitespace, and `required` flag."""

    def test_present_ok(self, tmp_path: Path) -> None:
        _write_tokenizer_config(tmp_path, chat_template="{{messages}}")
        check_chat_template(tmp_path)

    def test_missing_raises_by_default(self, tmp_path: Path) -> None:
        # An empty template string counts as missing.
        _write_tokenizer_config(tmp_path, chat_template="")
        with pytest.raises(PreflightError, match="chat_template"):
            check_chat_template(tmp_path)

    def test_whitespace_only_template_raises(self, tmp_path: Path) -> None:
        # Whitespace-only templates are as useless as empty ones.
        _write_tokenizer_config(tmp_path, chat_template=" ")
        with pytest.raises(PreflightError):
            check_chat_template(tmp_path)

    def test_required_false_skips_check(self, tmp_path: Path) -> None:
        # Even with no tokenizer_config.json present the check is a no-op.
        check_chat_template(tmp_path, required=False)

    def test_missing_file_raises_when_required(self, tmp_path: Path) -> None:
        with pytest.raises(PreflightError, match="missing"):
            check_chat_template(tmp_path, required=True)

    def test_malformed_config_raises(self, tmp_path: Path) -> None:
        (tmp_path / "tokenizer_config.json").write_text("not json {{{")
        with pytest.raises(PreflightError, match="cannot parse"):
            check_chat_template(tmp_path, required=True)
132
133
class TestQloraFlag:
    """`check_was_adapter_qlora`: training_run.json flag with pinned-versions fallback."""

    def test_missing_file_returns_false(self, tmp_path: Path) -> None:
        # Nothing on disk: assume the adapter was plain LoRA.
        assert check_was_adapter_qlora(tmp_path) is False

    def test_missing_training_run_falls_back_to_pinned_versions(self, tmp_path: Path) -> None:
        # A recorded bitsandbytes pin implies QLoRA training.
        _write_pinned_versions(tmp_path, bnb="0.43.1")
        assert check_was_adapter_qlora(tmp_path) is True

    def test_true_flag_returns_true(self, tmp_path: Path) -> None:
        payload = json.dumps({"use_qlora": True})
        (tmp_path / "training_run.json").write_text(payload)
        assert check_was_adapter_qlora(tmp_path) is True

    def test_false_flag_returns_false(self, tmp_path: Path) -> None:
        payload = json.dumps({"use_qlora": False})
        (tmp_path / "training_run.json").write_text(payload)
        assert check_was_adapter_qlora(tmp_path) is False

    def test_malformed_json_raises(self, tmp_path: Path) -> None:
        """Corrupt `training_run.json` must not silently bypass the pitfall-3 merge gate."""
        _write_pinned_versions(tmp_path, bnb="0.43.1")
        (tmp_path / "training_run.json").write_text("not json")
        with pytest.raises(PreflightError, match="training_run_json"):
            check_was_adapter_qlora(tmp_path)
156
157
class TestPretokenizerFingerprint:
    """`check_pretokenizer_fingerprint`: a failed probe must abort preflight."""

    def test_failed_probe_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Stub the probe so it reports a non-skipped, failed comparison.
        def _fake_probe(_spec: object) -> SimpleNamespace:
            return SimpleNamespace(skipped=False, passed=False, detail="mismatch")

        monkeypatch.setattr(
            "dlm.base_models.probes.probe_pretokenizer_hash", _fake_probe
        )

        with pytest.raises(PreflightError, match="pre-tokenizer fingerprint mismatch"):
            check_pretokenizer_fingerprint(_SPEC)
167
168
class TestVlTargetModulesLmOnly:
    """`check_vl_target_modules_lm_only`: only explicit vision targets fail."""

    def test_missing_config_is_noop(self, tmp_path: Path) -> None:
        # Absent adapter_config.json: nothing to validate.
        check_vl_target_modules_lm_only(tmp_path)

    def test_malformed_config_is_noop(self, tmp_path: Path) -> None:
        # Unparseable config is tolerated here (other checks report it).
        (tmp_path / "adapter_config.json").write_text("not json {{{")
        check_vl_target_modules_lm_only(tmp_path)

    def test_string_pattern_target_modules_is_noop(self, tmp_path: Path) -> None:
        # A regex-style string pattern cannot be inspected module-by-module.
        _write_adapter_config(tmp_path, target_modules=".*q_proj.*")
        check_vl_target_modules_lm_only(tmp_path)

    def test_vision_targets_raise(self, tmp_path: Path) -> None:
        targets = ["q_proj", "vision_tower.block.0"]
        _write_adapter_config(tmp_path, target_modules=targets)
        with pytest.raises(PreflightError, match="vision-tower modules"):
            check_vl_target_modules_lm_only(tmp_path)