Python · 5653 bytes Raw Blame History
1 """Speculative-decoding draft registry (Sprint 12.5)."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from dlm.base_models import BASE_MODELS
8 from dlm.base_models.schema import BaseModelSpec
9 from dlm.export.draft_registry import (
10 DRAFT_PAIRS,
11 DraftPair,
12 resolve_draft,
13 validate_registry,
14 )
15
# Real registry entries reused as resolve_draft() inputs by the tests below:
# the first is expected to resolve to a draft tag, the second to have none.
_QWEN_3B, _SMOLLM_135M = (
    BASE_MODELS["qwen2.5-3b"],
    BASE_MODELS["smollm2-135m"],
)
18
19
20 # --- Registry validator -----------------------------------------------------
21
22
class TestValidateRegistry:
    """Check the shipped DRAFT_PAIRS against the real BASE_MODELS registry."""

    @staticmethod
    def _lookup(pair: DraftPair):
        """Return the (target, draft) entries of BASE_MODELS that *pair* names."""
        return BASE_MODELS[pair.target_key], BASE_MODELS[pair.draft_registry_key]

    def test_every_pair_references_real_specs(self) -> None:
        # Passing means the validator accepted the real registry without raising.
        validate_registry(BASE_MODELS)

    def test_launch_pairs_cover_expected_targets(self) -> None:
        covered = {pair.target_key for pair in DRAFT_PAIRS}
        for expected in ("qwen2.5-3b", "llama-3.2-3b", "smollm2-1.7b"):
            assert expected in covered

    def test_every_pair_shares_template(self) -> None:
        for pair in DRAFT_PAIRS:
            target_spec, draft_spec = self._lookup(pair)
            # The offending pair is attached so a failure names it directly.
            assert target_spec.template == draft_spec.template, pair

    def test_every_pair_shares_tokenizer_pre(self) -> None:
        for pair in DRAFT_PAIRS:
            target_spec, draft_spec = self._lookup(pair)
            assert target_spec.tokenizer_pre == draft_spec.tokenizer_pre, pair
46
47
class TestValidatorRejectsMismatches:
    """Feed validate_registry hand-built registries and assert it flags drift.

    Each test builds a small ``{key: BaseModelSpec}`` mapping and, where a
    synthetic pairing is needed, swaps the module-level ``DRAFT_PAIRS`` for
    the duration of the test via ``monkeypatch``.
    """

    def _fake_spec(
        self, *, template: str = "chatml", tokenizer_pre: str = "qwen2"
    ) -> BaseModelSpec:
        """Build a throwaway spec; only ``template`` and ``tokenizer_pre`` vary."""
        return BaseModelSpec.model_validate(
            {
                "key": "fake",
                "hf_id": "org/fake",
                "revision": "0" * 40,
                "architecture": "FakeForCausalLM",
                "params": 1_000_000_000,
                "target_modules": ["q_proj"],
                "template": template,
                "gguf_arch": "fake",
                "tokenizer_pre": tokenizer_pre,
                "license_spdx": "MIT",
                "requires_acceptance": False,
                "redistributable": True,
                "size_gb_fp16": 2.0,
                "context_length": 2048,
                "recommended_seq_len": 1024,
            }
        )

    def _install_pairs(self, monkeypatch: pytest.MonkeyPatch, *pairs: DraftPair) -> None:
        """Replace the module-level DRAFT_PAIRS with *pairs* for this test only."""
        import dlm.export.draft_registry as mod

        # monkeypatch restores the original attribute when the test finishes.
        monkeypatch.setattr(mod, "DRAFT_PAIRS", pairs)

    def test_missing_target_key_raises(self) -> None:
        # DRAFT_PAIRS is deliberately left untouched here: the shipped pairs
        # name qwen2.5-3b as a target, and the synthetic registry below does
        # not contain that key.
        registry = {"only-draft": self._fake_spec()}
        # `match` is a regex search, so the dot in the key is escaped to keep
        # it literal instead of matching any character.
        with pytest.raises(ValueError, match=r"target_key 'qwen2\.5-3b' not in BASE_MODELS"):
            validate_registry(registry)

    def test_missing_draft_registry_key_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        bad_pair = DraftPair(
            target_key="a",
            draft_registry_key="missing",
            upstream_ollama_tag="a:tiny",
            notes="missing draft key",
        )
        self._install_pairs(monkeypatch, bad_pair)
        # The target exists, so only the draft lookup can be what fails.
        registry = {"a": self._fake_spec()}
        with pytest.raises(ValueError, match="draft_registry_key 'missing' not in BASE_MODELS"):
            validate_registry(registry)

    def test_mismatched_template_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        bad_pair = DraftPair(
            target_key="a",
            draft_registry_key="b",
            upstream_ollama_tag="a:b",
            notes="template mismatch",
        )
        self._install_pairs(monkeypatch, bad_pair)
        # Both keys resolve; the specs differ in template only.
        registry = {
            "a": self._fake_spec(template="chatml"),
            "b": self._fake_spec(template="llama3"),
        }
        with pytest.raises(ValueError, match="template"):
            validate_registry(registry)

    def test_mismatched_tokenizer_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        bad_pair = DraftPair(
            target_key="a",
            draft_registry_key="b",
            upstream_ollama_tag="a:b",
            notes="tokenizer mismatch",
        )
        self._install_pairs(monkeypatch, bad_pair)
        # Both keys resolve; the specs differ in tokenizer_pre only.
        registry = {
            "a": self._fake_spec(tokenizer_pre="qwen2"),
            "b": self._fake_spec(tokenizer_pre="llama-bpe"),
        }
        with pytest.raises(ValueError, match="tokenizer_pre"):
            validate_registry(registry)
129
130
131 # --- resolve_draft ----------------------------------------------------------
132
133
class TestResolveDraft:
    """Exercise resolve_draft's registry lookup and its disabled/override arguments."""

    def test_registered_target_returns_tag(self) -> None:
        tag = resolve_draft(_QWEN_3B)
        assert tag == "qwen2.5:0.5b"

    def test_unregistered_target_returns_none(self) -> None:
        tag = resolve_draft(_SMOLLM_135M)
        assert tag is None

    def test_disabled_returns_none(self) -> None:
        tag = resolve_draft(_QWEN_3B, disabled=True)
        assert tag is None

    def test_override_returns_override_verbatim(self) -> None:
        tag = resolve_draft(_QWEN_3B, override="custom:1b")
        assert tag == "custom:1b"

    def test_override_wins_even_without_registry_match(self) -> None:
        tag = resolve_draft(_SMOLLM_135M, override="x:y")
        assert tag == "x:y"

    def test_disabled_wins_over_override(self) -> None:
        """Disabling drafts (--no-draft) takes priority over an explicit --draft override."""
        tag = resolve_draft(_QWEN_3B, disabled=True, override="should-not-appear")
        assert tag is None