| 1 |
"""Static mean-gate fallback for Ollama / llama.cpp export. |
| 2 |
|
| 3 |
The learned gate runs in PyTorch at `dlm prompt` time. The |
| 4 |
GGUF runtime (Ollama, llama.cpp) can't evaluate a torch module at |
| 5 |
inference, so when the user runs `dlm export` on a document with |
| 6 |
`training.gate.enabled: true` we fall back to: |
| 7 |
|
| 8 |
1. Compute the gate's softmax output on every training prompt. |
| 9 |
2. Average those probability vectors across the corpus → one fixed |
| 10 |
weight per adapter. |
| 11 |
3. Emit the averaged weights as the Modelfile's `--adapter-mix` |
| 12 |
coefficients. |
| 13 |
|
| 14 |
The exported model is a statically-weighted merge of the named |
| 15 |
adapters — lossless vs today's shipped behavior, and strictly better |
| 16 |
than asking the user to guess coefficients. Dynamic per-prompt routing |
| 17 |
is the `dlm prompt` / `dlm repl` path only. |
| 18 |
|
| 19 |
The export manifest records ``gate_mode: "static_mean"`` so downstream |
| 20 |
tooling can tell an exported-with-mean-gate build apart from a |
| 21 |
hand-picked `--adapter-mix`. |
| 22 |
""" |
| 23 |
|
| 24 |
from __future__ import annotations |
| 25 |
|
| 26 |
from dataclasses import dataclass, field |
| 27 |
from typing import TYPE_CHECKING |
| 28 |
|
| 29 |
if TYPE_CHECKING: |
| 30 |
import torch |
| 31 |
|
| 32 |
from dlm.doc.parser import ParsedDlm |
| 33 |
from dlm.store.paths import StorePath |
| 34 |
from dlm.train.gate.module import Gate, GateMetadata |
| 35 |
|
| 36 |
|
| 37 |
def mean_gate_weights(
    gate: Gate,
    metadata: GateMetadata,
    prompt_embeddings: list[torch.Tensor],
) -> list[tuple[str, float]]:
    """Average ``gate(embedding)`` across the training prompts.

    Returns ``[(adapter_name, weight), ...]`` suitable for direct
    substitution into ``dlm export --adapter-mix``. Weights sum to
    1.0 (gate output is softmax; average of softmax is still on the
    simplex) but we don't renormalize defensively — a numeric-drift
    renorm would mask bugs.

    Raises ``ValueError`` if ``prompt_embeddings`` is empty (a
    zero-prompt corpus has nothing to average), if the embedding
    dimension disagrees with ``metadata.input_dim``, or if the gate's
    output width disagrees with ``metadata.adapter_names``.
    """
    import torch  # local import: keep the torch dependency out of module import time

    if not prompt_embeddings:
        raise ValueError("mean_gate_weights requires >= 1 prompt embedding")

    with torch.no_grad():
        # Flatten each embedding to 1-D and promote to float32 so a
        # mixed-dtype corpus (fp16 / bf16 prompts) averages cleanly.
        stacked = torch.stack(
            [e.detach().to(torch.float32).reshape(-1) for e in prompt_embeddings]
        )
        if stacked.shape[1] != metadata.input_dim:
            raise ValueError(
                f"prompt embedding dim {stacked.shape[1]} != gate input_dim "
                f"{metadata.input_dim} (base model mismatch?)"
            )
        probs = gate(stacked)  # (N, n_adapters)
        # Guard against a stale gate checkpoint whose head width no longer
        # matches the declared adapter list: without this, the enumerate()
        # below would silently drop extra columns, or IndexError with an
        # opaque message when columns are missing.
        if probs.shape[1] != len(metadata.adapter_names):
            raise ValueError(
                f"gate produced {probs.shape[1]} adapter weights but metadata "
                f"names {len(metadata.adapter_names)} adapters"
            )
        mean = probs.mean(dim=0)

    return [(name, float(mean[i].item())) for i, name in enumerate(metadata.adapter_names)]
| 69 |
|
| 70 |
|
| 71 |
def uniform_adapter_mix(adapter_names: tuple[str, ...]) -> list[tuple[str, float]]:
    """Mean-gate fallback for uniform-mode gates (cold-start).

    Returns ``[(name, 1/N), ...]`` — the export path for a doc that has
    a gate declared but where the gate trainer chose the uniform
    fallback because the corpus was too small.
    """
    if not adapter_names:
        return []
    share = 1.0 / len(adapter_names)
    return [(name, share) for name in adapter_names]
| 83 |
|
| 84 |
|
| 85 |
def resolve_gate_mix(
    store: object,
    parsed: object,
) -> list[tuple[str, float]] | None:
    """Derive a static ``--adapter-mix`` from the learned gate's state.

    Returns ``None`` when the document has no enabled gate, declares
    fewer than two adapters, or the store has no persisted
    ``gate_config.json``. Otherwise returns one of:

    - **uniform mode** → ``uniform_adapter_mix(adapter_names)``
    - **trained mode** → the last recorded ``gate_events`` row set,
      mapped into ``(name, mean_weight)`` pairs. When no events have
      been recorded yet (e.g. gate trained but metrics DB empty) we
      fall back to uniform rather than refusing the export.

    This is the static substitution `export_cmd` uses when the user
    didn't pass ``--adapter-mix`` on a gate-enabled document — the
    Ollama / llama.cpp runtime can't evaluate the gate dynamically,
    so we freeze the learned prior at export time.
    """
    import json

    from dlm.doc.parser import ParsedDlm
    from dlm.metrics import queries as gate_queries
    from dlm.store.paths import StorePath
    from dlm.train.gate.module import GateMetadata
    from dlm.train.gate.paths import gate_config_path

    # Duck-typed CLI entry point: bail out unless both arguments are the
    # concrete types this resolver understands.
    if not (isinstance(store, StorePath) and isinstance(parsed, ParsedDlm)):
        return None
    training = parsed.frontmatter.training
    if not training.gate.enabled:
        return None
    declared = training.adapters
    if declared is None or len(declared) < 2:
        return None

    config_file = gate_config_path(store)
    if not config_file.exists():
        return None
    persisted = json.loads(config_file.read_text(encoding="utf-8"))
    gate_meta = GateMetadata.from_json(persisted)
    adapter_names = tuple(gate_meta.adapter_names)

    if gate_meta.mode == "uniform":
        return uniform_adapter_mix(adapter_names)

    event_rows = gate_queries.latest_gate_events(store.root)
    if not event_rows:
        # Trained gate but an empty metrics DB: prefer a uniform export
        # over refusing to export at all.
        return uniform_adapter_mix(adapter_names)
    weight_by_name = {row.adapter_name: row.mean_weight for row in event_rows}
    # Preserve declared adapter order — the Modelfile consumer reads
    # positionally-meaningful `--adapter-mix` tuples.
    return [(name, weight_by_name.get(name, 0.0)) for name in adapter_names]
| 140 |
|
| 141 |
|
| 142 |
@dataclass(frozen=True)
class GateMixResolution:
    """Outcome of :func:`resolve_and_announce`.

    Attributes:
        entries: The ``--adapter-mix``-shaped list of ``(name, weight)``
            pairs, or ``None`` when no substitution applies (doc has no
            gate, gate not trained, or fewer than two adapters).
        banner_lines: Pre-formatted Rich markup the CLI should print
            when a substitution IS made — empty on the no-substitution
            path.
    """

    entries: list[tuple[str, float]] | None
    banner_lines: list[str] = field(default_factory=list)
| 155 |
|
| 156 |
|
| 157 |
def resolve_and_announce(store: StorePath, parsed: ParsedDlm) -> GateMixResolution:
    """Pair :func:`resolve_gate_mix` with the CLI substitution banner.

    Consolidates the two-step dance the CLI used to do inline: call
    ``resolve_gate_mix`` then ``console.print`` a substitution notice
    on a non-``None`` result. The CLI now iterates
    ``resolution.banner_lines`` (empty or one line) and uses
    ``resolution.entries`` as-is — no separate print call, no
    duplicated substitution-condition check.
    """
    mix = resolve_gate_mix(store, parsed)
    if mix is None:
        # No substitution: no entries, no banner for the CLI to print.
        return GateMixResolution(entries=None)
    banner = (
        "[dim]export: substituting learned gate weights for "
        "--adapter-mix (gate_mode=static).[/dim]"
    )
    return GateMixResolution(entries=mix, banner_lines=[banner])