Python · 6604 bytes Raw Blame History
1 """Static mean-gate fallback for Ollama / llama.cpp export.
2
3 The learned gate runs in PyTorch at `dlm prompt` time. The
4 GGUF runtime (Ollama, llama.cpp) can't evaluate a torch module at
5 inference, so when the user runs `dlm export` on a document with
6 `training.gate.enabled: true` we fall back to:
7
8 1. Compute the gate's softmax output on every training prompt.
9 2. Average those probability vectors across the corpus → one fixed
10 weight per adapter.
11 3. Emit the averaged weights as the Modelfile's `--adapter-mix`
12 coefficients.
13
14 The exported model is a statically-weighted merge of the named
15 adapters — lossless vs today's shipped behavior, and strictly better
16 than asking the user to guess coefficients. Dynamic per-prompt routing
17 is the `dlm prompt` / `dlm repl` path only.
18
19 The export manifest records ``gate_mode: "static_mean"`` so downstream
20 tooling can tell an exported-with-mean-gate build apart from a
21 hand-picked `--adapter-mix`.
22 """
23
24 from __future__ import annotations
25
26 from dataclasses import dataclass, field
27 from typing import TYPE_CHECKING
28
29 if TYPE_CHECKING:
30 import torch
31
32 from dlm.doc.parser import ParsedDlm
33 from dlm.store.paths import StorePath
34 from dlm.train.gate.module import Gate, GateMetadata
35
36
def mean_gate_weights(
    gate: Gate,
    metadata: GateMetadata,
    prompt_embeddings: list[torch.Tensor],
) -> list[tuple[str, float]]:
    """Average ``gate(embedding)`` across the training prompts.

    Every embedding is detached, cast to float32 and flattened to 1-D;
    the flattened vectors are stacked into a single ``(N, input_dim)``
    batch and the gate is evaluated once on that batch with gradient
    tracking disabled.

    Args:
        gate: Callable applied to the stacked batch; expected to return
            an ``(N, n_adapters)`` tensor of per-adapter probabilities.
        metadata: Supplies ``input_dim`` (required flattened embedding
            size) and ``adapter_names`` (one name per gate output column).
        prompt_embeddings: One tensor per training prompt.

    Returns:
        ``[(adapter_name, weight), ...]`` in ``metadata.adapter_names``
        order, suitable for direct substitution into
        ``dlm export --adapter-mix``. When the gate emits softmax rows
        the weights sum to 1.0 (a mean of points on the simplex stays on
        the simplex). The result is deliberately not renormalized — a
        numeric-drift renorm would mask bugs.

    Raises:
        ValueError: If ``prompt_embeddings`` is empty (a zero-prompt
            corpus has nothing to average), if the flattened embedding
            size differs from ``metadata.input_dim``, or if the gate
            output is not shaped ``(N, len(metadata.adapter_names))``.
    """
    import torch

    if not prompt_embeddings:
        raise ValueError("mean_gate_weights requires >= 1 prompt embedding")

    adapter_names = list(metadata.adapter_names)

    with torch.no_grad():
        stacked = torch.stack(
            [e.detach().to(torch.float32).reshape(-1) for e in prompt_embeddings]
        )
        if stacked.shape[1] != metadata.input_dim:
            raise ValueError(
                f"prompt embedding dim {stacked.shape[1]} != gate input_dim "
                f"{metadata.input_dim} (base model mismatch?)"
            )
        probs = gate(stacked)  # (N, n_adapters)

        # Validate the output side as strictly as the input side: extra
        # columns would otherwise be dropped silently (weights no longer
        # summing to 1) and missing columns would surface as a bare
        # IndexError with no hint about the cause.
        expected_shape = (stacked.shape[0], len(adapter_names))
        if tuple(probs.shape) != expected_shape:
            raise ValueError(
                f"gate output shape {tuple(probs.shape)} != expected "
                f"{expected_shape} (adapter_names mismatch?)"
            )
        mean = probs.mean(dim=0)

    # Shape was checked above, so names and weights pair up one-to-one.
    return list(zip(adapter_names, mean.tolist()))
69
70
def uniform_adapter_mix(adapter_names: tuple[str, ...]) -> list[tuple[str, float]]:
    """Spread the mix evenly across ``adapter_names``.

    This is the mean-gate fallback for uniform-mode (cold-start) gates:
    the export path for a document that declares a gate but whose gate
    trainer settled on the uniform fallback because the corpus was too
    small.

    Returns ``[(name, 1/N), ...]`` in the given order, or ``[]`` when no
    names are supplied.
    """
    count = len(adapter_names)
    if not count:
        return []
    share = 1.0 / count
    # Pair every name with the same precomputed share.
    return list(zip(adapter_names, [share] * count))
83
84
def resolve_gate_mix(
    store: object,
    parsed: object,
) -> list[tuple[str, float]] | None:
    """Freeze the learned gate into a static ``--adapter-mix``.

    `export_cmd` uses this substitution when the user did not pass
    ``--adapter-mix`` on a gate-enabled document: the Ollama / llama.cpp
    runtime cannot evaluate the gate dynamically, so the learned prior is
    frozen at export time.

    ``None`` is returned when any of these hold:

    - ``store`` is not a ``StorePath`` or ``parsed`` is not a ``ParsedDlm``;
    - the document's gate is not enabled;
    - the document declares fewer than two adapters;
    - the store has no persisted ``gate_config.json``.

    Otherwise the result is one of:

    - **uniform mode** → ``uniform_adapter_mix(adapter_names)``
    - **trained mode** → the latest recorded ``gate_events`` rows mapped
      to ``(name, mean_weight)`` pairs, with ``0.0`` for any adapter that
      has no row. If no events exist yet (e.g. gate trained but metrics
      DB empty) the uniform mix is used rather than refusing the export.
    """
    import json

    from dlm.doc.parser import ParsedDlm
    from dlm.metrics import queries as _queries
    from dlm.store.paths import StorePath
    from dlm.train.gate.module import GateMetadata
    from dlm.train.gate.paths import gate_config_path

    if not (isinstance(store, StorePath) and isinstance(parsed, ParsedDlm)):
        return None

    training_cfg = parsed.frontmatter.training
    if not training_cfg.gate.enabled:
        return None

    declared = training_cfg.adapters
    if declared is None or len(declared) < 2:
        return None

    config_file = gate_config_path(store)
    if not config_file.exists():
        return None

    payload = json.loads(config_file.read_text(encoding="utf-8"))
    gate_meta = GateMetadata.from_json(payload)
    names = tuple(gate_meta.adapter_names)

    # Uniform-mode gates never consult the metrics DB; trained gates with
    # no recorded events take the same uniform fallback.
    recorded = None if gate_meta.mode == "uniform" else _queries.latest_gate_events(store.root)
    if not recorded:
        return uniform_adapter_mix(names)

    weight_by_adapter = {}
    for event in recorded:
        weight_by_adapter[event.adapter_name] = event.mean_weight

    # Preserve declared adapter order — the Modelfile consumer reads
    # positionally-meaningful `--adapter-mix` tuples.
    mix = []
    for name in names:
        mix.append((name, weight_by_adapter.get(name, 0.0)))
    return mix
140
141
@dataclass(frozen=True)
class GateMixResolution:
    """Immutable outcome of :func:`resolve_and_announce`.

    Bundles the resolved mix with the text the CLI should print about
    it, so the caller needs no separate "was a substitution made?"
    check.
    """

    # `--adapter-mix`-shaped ``(name, weight)`` pairs, or ``None`` when no
    # substitution applies (doc has no gate, gate not trained, or fewer
    # than two adapters).
    entries: list[tuple[str, float]] | None
    # Pre-formatted Rich markup for the CLI to print when a substitution
    # IS made; defaults to a fresh empty list on the no-substitution path.
    banner_lines: list[str] = field(default_factory=list)
155
156
def resolve_and_announce(store: StorePath, parsed: ParsedDlm) -> GateMixResolution:
    """Pair :func:`resolve_gate_mix` with the CLI substitution banner.

    Consolidates the two-step dance the CLI used to do inline: call
    ``resolve_gate_mix`` then ``console.print`` a substitution notice
    on a non-``None`` result. The CLI now iterates
    ``resolution.banner_lines`` (empty or one line) and uses
    ``resolution.entries`` as-is — no separate print call, no
    duplicated substitution-condition check.

    Args:
        store: Store location forwarded to ``resolve_gate_mix``.
        parsed: Parsed document forwarded to ``resolve_gate_mix``.

    Returns:
        A :class:`GateMixResolution` whose ``entries`` is the resolved
        mix (or ``None``) and whose ``banner_lines`` holds exactly one
        Rich-markup line when a substitution was made, none otherwise.
    """
    entries = resolve_gate_mix(store, parsed)
    if entries is None:
        return GateMixResolution(entries=None)
    return GateMixResolution(
        entries=entries,
        banner_lines=[
            # Report the same gate_mode value the module documentation
            # says the export manifest records ("static_mean"), so the
            # banner and the manifest agree.
            "[dim]export: substituting learned gate weights for "
            "--adapter-mix (gate_mode=static_mean).[/dim]"
        ],
    )