tenseleyflow/documentlanguagemodel / bd96149

Browse files

feat(export.gate_fallback): mean-gate static adapter-mix for Ollama/GGUF

The GGUF runtime can't evaluate the learned torch gate at inference,
so export averages the gate's softmax output across training prompts
and emits the averaged weights as --adapter-mix coefficients.
Lossless vs today's shipped behavior. The export manifest will record
gate_mode='static_mean' so downstream tooling can tell these builds
apart from hand-picked mixes. Uniform-mode gates (cold-start) use the
corresponding uniform 1/N split directly.
Authored by espadonne
SHA
bd96149bdc74160d746c55c9edb7fd07d05bff6c
Parents
cbcdb09
Tree
241ea2e

2 changed files

StatusFile+-
A src/dlm/export/gate_fallback.py 79 0
A tests/unit/export/test_gate_fallback.py 81 0
src/dlm/export/gate_fallback.pyadded
@@ -0,0 +1,79 @@
1
+"""Static mean-gate fallback for Ollama / llama.cpp export.
2
+
3
+The learned gate (Sprint 34) runs in PyTorch at `dlm prompt` time. The
4
+GGUF runtime (Ollama, llama.cpp) can't evaluate a torch module at
5
+inference, so when the user runs `dlm export` on a document with
6
+`training.gate.enabled: true` we fall back to:
7
+
8
+1. Compute the gate's softmax output on every training prompt.
9
+2. Average those probability vectors across the corpus → one fixed
10
+   weight per adapter.
11
+3. Emit the averaged weights as the Modelfile's `--adapter-mix`
12
+   coefficients.
13
+
14
+The exported model is a statically-weighted merge of the named
15
+adapters — lossless vs today's shipped behavior, and strictly better
16
+than asking the user to guess coefficients. Dynamic per-prompt routing
17
+is the `dlm prompt` / `dlm repl` path only.
18
+
19
+The export manifest records ``gate_mode: "static_mean"`` so downstream
20
+tooling can tell an exported-with-mean-gate build apart from a
21
+hand-picked `--adapter-mix`.
22
+"""
23
+
24
+from __future__ import annotations
25
+
26
+from typing import TYPE_CHECKING
27
+
28
+if TYPE_CHECKING:
29
+    import torch
30
+
31
+    from dlm.train.gate.module import Gate, GateMetadata
32
+
33
+
34
def mean_gate_weights(
    gate: Gate,
    metadata: GateMetadata,
    prompt_embeddings: list[torch.Tensor],
) -> list[tuple[str, float]]:
    """Average ``gate(embedding)`` across the training prompts.

    Returns ``[(adapter_name, weight), ...]`` suitable for direct
    substitution into ``dlm export --adapter-mix``. Weights sum to
    1.0 (gate output is softmax; an average of points on the simplex
    stays on the simplex) but we don't renormalize defensively — a
    numeric-drift renorm would mask bugs.

    Raises ``ValueError`` if ``prompt_embeddings`` is empty (a
    zero-prompt corpus has nothing to average), if an embedding's
    dimension disagrees with ``metadata.input_dim``, or if the gate's
    output width disagrees with ``metadata.adapter_names``.
    """
    import torch

    if not prompt_embeddings:
        raise ValueError("mean_gate_weights requires >= 1 prompt embedding")

    with torch.no_grad():
        # Flatten each embedding so (D,), (1, D), etc. all stack to (N, D).
        stacked = torch.stack(
            [e.detach().to(torch.float32).reshape(-1) for e in prompt_embeddings]
        )
        if stacked.shape[1] != metadata.input_dim:
            raise ValueError(
                f"prompt embedding dim {stacked.shape[1]} != gate input_dim "
                f"{metadata.input_dim} (base model mismatch?)"
            )
        probs = gate(stacked)  # (N, n_adapters)
        # Guard against a gate checkpoint that disagrees with its metadata:
        # without this check the enumerate below would silently truncate (or
        # IndexError) and ship wrong adapter-mix coefficients.
        if probs.shape[1] != len(metadata.adapter_names):
            raise ValueError(
                f"gate emitted {probs.shape[1]} adapter weights but metadata "
                f"names {len(metadata.adapter_names)} adapters"
            )
        mean = probs.mean(dim=0)

    return [(name, float(mean[i].item())) for i, name in enumerate(metadata.adapter_names)]
66
+
67
+
68
def uniform_adapter_mix(adapter_names: tuple[str, ...]) -> list[tuple[str, float]]:
    """Uniform 1/N mix for uniform-mode (cold-start) gates.

    Export path for a document that declares a gate but whose trainer
    fell back to the uniform split because the corpus was too small.
    Returns ``[(name, 1/N), ...]``; an empty name tuple yields ``[]``.
    """
    if not adapter_names:
        return []
    share = 1.0 / len(adapter_names)
    return [(name, share) for name in adapter_names]
tests/unit/export/test_gate_fallback.pyadded
@@ -0,0 +1,81 @@
1
+"""Tests for the static mean-gate export fallback (Ollama / llama.cpp)."""
2
+
3
+from __future__ import annotations
4
+
5
+import pytest
6
+
7
+from dlm.export.gate_fallback import mean_gate_weights, uniform_adapter_mix
8
+from dlm.train.gate.module import Gate, GateMetadata
9
+
10
+
11
class TestUniformAdapterMix:
    """Uniform 1/N split for cold-start gates."""

    def test_three_adapters_third_each(self) -> None:
        expected = [("a", 1 / 3), ("b", 1 / 3), ("c", 1 / 3)]
        assert uniform_adapter_mix(("a", "b", "c")) == expected

    def test_empty_tuple(self) -> None:
        assert uniform_adapter_mix(()) == []
18
+
19
+
20
class TestMeanGateWeights:
    """Behavior of the static mean-gate export fallback."""

    def _gate_and_meta(self) -> tuple[Gate, GateMetadata]:
        # Small 2-adapter gate plus matching metadata, shared by the tests.
        gate = Gate(input_dim=8, hidden_proj_dim=4, n_adapters=2)
        meta = GateMetadata(
            input_dim=8,
            hidden_proj_dim=4,
            adapter_names=("a", "b"),
            mode="trained",
        )
        return gate, meta

    def test_empty_corpus_refused(self) -> None:
        gate, meta = self._gate_and_meta()
        with pytest.raises(ValueError, match=">= 1 prompt embedding"):
            mean_gate_weights(gate, meta, [])

    def test_weights_shape_and_sum_to_one(self) -> None:
        import torch

        gate, meta = self._gate_and_meta()
        embeddings = [torch.randn(8) for _ in range(16)]
        mix = mean_gate_weights(gate, meta, embeddings)
        assert [name for name, _ in mix] == ["a", "b"]
        total = sum(w for _, w in mix)
        assert total == pytest.approx(1.0, abs=1e-5)
        for _, w in mix:
            assert 0.0 <= w <= 1.0

    def test_dim_mismatch_refused(self) -> None:
        import torch

        gate, meta = self._gate_and_meta()
        # Wrong-dim embedding.
        with pytest.raises(ValueError, match="input_dim"):
            mean_gate_weights(gate, meta, [torch.randn(4)])

    def test_mean_reflects_per_prompt_skew(self) -> None:
        """Ten prompts near cluster A + one prompt near cluster B should
        average differently than a single A prompt. Sanity check that
        mean_gate_weights isn't just emitting uniform."""
        import torch

        # Seed BEFORE constructing the gate so its random init — and hence
        # this test's final inequality — is deterministic run to run.
        # (The seed used to be set after Gate(), leaving the gate weights
        # unseeded and the assertion below technically flaky.)
        torch.manual_seed(0)
        gate = Gate(input_dim=8, hidden_proj_dim=8, n_adapters=2)
        meta = GateMetadata(
            input_dim=8,
            hidden_proj_dim=8,
            adapter_names=("a", "b"),
            mode="trained",
        )
        # Class-a embeddings near +1, class-b near -1.
        a_embeddings = [torch.ones(8) + 0.01 * torch.randn(8) for _ in range(10)]
        b_embedding = -torch.ones(8)
        # We don't train here — an untrained gate may or may not favor A.
        # The point is only that the mean is a real average (not uniform
        # or fixed), which we check by comparing against a single-prompt
        # case.
        mix_mixed = mean_gate_weights(gate, meta, a_embeddings + [b_embedding])
        mix_single_a = mean_gate_weights(gate, meta, [a_embeddings[0]])
        # Different input distributions → different averaged outputs.
        assert mix_mixed != mix_single_a