Python · 6703 bytes Raw Blame History
1 """`resolve_gate_mix` — gate-driven static --adapter-mix substitution."""
2
3 from __future__ import annotations
4
5 import json
6 from dataclasses import replace
7 from pathlib import Path
8 from types import MappingProxyType
9
10 from dlm.doc.parser import ParsedDlm
11 from dlm.doc.schema import AdapterConfig, DlmFrontmatter, GateConfig, TrainingConfig
12 from dlm.doc.sections import Section, SectionType
13 from dlm.export.gate_fallback import resolve_and_announce, resolve_gate_mix
14 from dlm.metrics.events import GateEvent, RunStart
15 from dlm.metrics.recorder import MetricsRecorder
16 from dlm.store.paths import StorePath
17 from dlm.train.gate.module import GateMetadata
18 from dlm.train.gate.paths import gate_config_path
19
20
def _parsed(
    *,
    gate_enabled: bool = True,
    adapters: tuple[str, ...] = ("a", "b"),
) -> ParsedDlm:
    """Build a minimal in-memory ``ParsedDlm`` for the gate-mix tests.

    Args:
        gate_enabled: Value placed in ``training.gate.enabled``.
        adapters: Declared adapter names; each maps to ``AdapterConfig(lora_r=4)``.
            An empty tuple yields ``training.adapters=None``.

    Returns:
        A document with one prose section (``"body"``) and no source path.
    """
    adapter_map = None
    if adapters:
        adapter_map = {adapter: AdapterConfig(lora_r=4) for adapter in adapters}

    training = TrainingConfig(
        adapters=adapter_map,
        gate=GateConfig(enabled=gate_enabled),
    )
    frontmatter = DlmFrontmatter(
        dlm_id="01HRSHWZ" + "0" * 18,
        dlm_version=8,
        base_model="smollm2-135m",
        training=training,
    )
    prose = Section(
        type=SectionType.PROSE,
        content="body",
        start_line=0,
        adapter=None,
        tags=MappingProxyType({}),
    )
    return ParsedDlm(frontmatter=frontmatter, sections=(prose,), source_path=None)
48
49
def _write_gate_config(store: StorePath, meta: GateMetadata) -> None:
    """Serialize ``meta`` as UTF-8 JSON at the store's gate-config path.

    Missing parent directories are created first.
    """
    target = gate_config_path(store)
    target.parent.mkdir(exist_ok=True, parents=True)
    payload = json.dumps(meta.to_json())
    target.write_text(payload, encoding="utf-8")
54
55
def test_gate_disabled_returns_none(tmp_path: Path) -> None:
    """A document with ``gate.enabled=False`` gets no gate-derived mix."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    doc = _parsed(gate_enabled=False)
    assert resolve_gate_mix(store, doc) is None
60
61
def test_no_adapters_returns_none(tmp_path: Path) -> None:
    """An enabled gate with an empty adapter map must not produce a mix.

    The schema refuses ``gate.enabled`` without >=2 adapters, so the doc is
    built with the gate disabled and no adapters, then the gate is flipped on
    with ``model_copy`` (which does not re-validate). A gate config is also
    written, so neither the disabled-gate nor the missing-config short-circuit
    can explain a ``None`` result — only the empty adapter map can.
    """
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    _write_gate_config(
        store,
        GateMetadata(
            input_dim=576,
            hidden_proj_dim=64,
            adapter_names=("a", "b"),
            mode="uniform",
        ),
    )

    parsed = _parsed(gate_enabled=False, adapters=())
    gated_training = parsed.frontmatter.training.model_copy(
        update={"gate": GateConfig(enabled=True)}
    )
    gated_frontmatter = parsed.frontmatter.model_copy(update={"training": gated_training})

    # Sanity-check the fixture: the gate really is on and no adapters are declared.
    assert gated_frontmatter.training.gate.enabled
    assert not gated_frontmatter.training.adapters

    assert resolve_gate_mix(store, replace(parsed, frontmatter=gated_frontmatter)) is None
69
70
def test_no_gate_config_returns_none(tmp_path: Path) -> None:
    """Gate enabled, but no gate config on disk: nothing to substitute."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    doc = _parsed()
    assert resolve_gate_mix(store, doc) is None
75
76
def test_single_adapter_returns_none(tmp_path: Path) -> None:
    """A single adapter cannot be mixed, even with the gate forced on.

    The doc is built with the gate disabled (the schema refuses an enabled
    gate with fewer than two adapters), then the gate is switched on through
    ``model_copy``.
    """
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    solo_doc = _parsed(gate_enabled=False, adapters=("solo",))

    gated_training = solo_doc.frontmatter.training.model_copy(
        update={"gate": GateConfig(enabled=True)}
    )
    gated_frontmatter = solo_doc.frontmatter.model_copy(update={"training": gated_training})
    gated_doc = replace(solo_doc, frontmatter=gated_frontmatter)

    assert resolve_gate_mix(store, gated_doc) is None
88
89
def test_non_store_or_non_parsed_returns_none() -> None:
    """Arguments of the wrong types are tolerated and yield ``None``."""
    bogus_store, bogus_doc = object(), object()
    assert resolve_gate_mix(bogus_store, bogus_doc) is None
92
93
def test_uniform_mode_returns_uniform_mix(tmp_path: Path) -> None:
    """A ``mode="uniform"`` gate config splits the weight evenly."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    meta = GateMetadata(
        input_dim=576,
        hidden_proj_dim=64,
        adapter_names=("a", "b"),
        mode="uniform",
    )
    _write_gate_config(store, meta)

    expected = [("a", 0.5), ("b", 0.5)]
    assert resolve_gate_mix(store, _parsed()) == expected
108
109
def test_trained_mode_uses_latest_events(tmp_path: Path) -> None:
    """A ``mode="trained"`` config takes per-adapter weights from gate events."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    _write_gate_config(
        store,
        GateMetadata(
            input_dim=576,
            hidden_proj_dim=64,
            adapter_names=("a", "b"),
            mode="trained",
        ),
    )

    recorder = MetricsRecorder(store.root)
    recorder.record_run_start(RunStart(run_id=1, adapter_version=1, phase="sft", seed=42))
    for adapter_name, weight in (("a", 0.7), ("b", 0.3)):
        recorder.record_gate(
            GateEvent(
                run_id=1,
                adapter_name=adapter_name,
                mean_weight=weight,
                sample_count=10,
                mode="trained",
            )
        )

    # NOTE(review): only one run is recorded, so this does not yet tell
    # "latest run" apart from "any run" — consider adding an older run.
    assert resolve_gate_mix(store, _parsed()) == [("a", 0.7), ("b", 0.3)]
132
133
def test_trained_mode_without_events_falls_back_to_uniform(tmp_path: Path) -> None:
    """A trained gate with no recorded gate events degrades to a uniform mix."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    trained_meta = GateMetadata(
        input_dim=576,
        hidden_proj_dim=64,
        adapter_names=("a", "b"),
        mode="trained",
    )
    _write_gate_config(store, trained_meta)

    uniform = [("a", 0.5), ("b", 0.5)]
    assert resolve_gate_mix(store, _parsed()) == uniform
148
149
def test_preserves_declared_adapter_order(tmp_path: Path) -> None:
    """The mix keeps the declared (unsorted) adapter order."""
    declared = ("zeta", "alpha")  # deliberately not alphabetical
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    _write_gate_config(
        store,
        GateMetadata(
            input_dim=576,
            hidden_proj_dim=64,
            adapter_names=declared,
            mode="trained",
        ),
    )

    recorder = MetricsRecorder(store.root)
    recorder.record_run_start(RunStart(run_id=1, adapter_version=1, phase="sft", seed=42))
    for adapter_name, weight in zip(declared, (0.4, 0.6)):
        recorder.record_gate(
            GateEvent(
                run_id=1,
                adapter_name=adapter_name,
                mean_weight=weight,
                sample_count=10,
                mode="trained",
            )
        )

    mix = resolve_gate_mix(store, _parsed(adapters=declared))
    # Order must follow the config's adapter_names tuple, not alphabetical order.
    # (The doc and the events use the same order here too.)
    assert mix == [("zeta", 0.4), ("alpha", 0.6)]
173
174
def test_resolve_and_announce_no_substitution(tmp_path: Path) -> None:
    """With the gate disabled there are no entries and no banner."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()

    outcome = resolve_and_announce(store, _parsed(gate_enabled=False))

    assert outcome.entries is None
    assert outcome.banner_lines == []
183
184
def test_resolve_and_announce_substitution_banner(tmp_path: Path) -> None:
    """A substituted mix comes with a single dim-styled explanatory banner."""
    store = StorePath(root=tmp_path)
    store.ensure_layout()
    meta = GateMetadata(
        input_dim=576,
        hidden_proj_dim=64,
        adapter_names=("a", "b"),
        mode="uniform",
    )
    _write_gate_config(store, meta)

    outcome = resolve_and_announce(store, _parsed())

    expected_banner = (
        "[dim]export: substituting learned gate weights for --adapter-mix (gate_mode=static).[/dim]"
    )
    assert outcome.entries == [("a", 0.5), ("b", 0.5)]
    assert outcome.banner_lines == [expected_banner]