Python · 3653 bytes Raw Blame History
1 """Gate nn.Module — shape + parameter count + metadata round-trip."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from dlm.train.gate import Gate, GateConfigError, GateMetadata
8
9
class TestGateConstruction:
    """Forward-pass shape/normalisation and constructor validation of ``Gate``."""

    def test_forward_shape(self) -> None:
        """A 2-D batch yields one probability row per sample, one column per adapter."""
        import torch

        batch, n_out = 5, 3
        module = Gate(input_dim=128, hidden_proj_dim=32, n_adapters=n_out)
        probs = module(torch.randn(batch, 128))
        assert probs.shape == (batch, n_out)
        # Output is a softmax over adapters, so every row should total ~1.
        row_totals = probs.sum(dim=-1)
        assert torch.allclose(row_totals, torch.ones(batch), atol=1e-5)

    def test_parameter_count(self) -> None:
        """``num_parameters()`` equals the closed-form weight + bias count."""
        in_dim, proj_dim, n_out = 2048, 64, 4
        module = Gate(input_dim=in_dim, hidden_proj_dim=proj_dim, n_adapters=n_out)
        # input -> hidden projection, then hidden -> one output per adapter;
        # each term is a weight matrix plus a bias vector.
        projection = in_dim * proj_dim + proj_dim
        head = proj_dim * n_out + n_out
        assert module.num_parameters() == projection + head

    def test_batch_dim_preserved(self) -> None:
        """Leading dims pass through untouched; only the feature axis changes size."""
        import torch

        module = Gate(input_dim=16, hidden_proj_dim=8, n_adapters=2)
        # Input laid out as (batch, time, features).
        probs = module(torch.randn(2, 7, 16))
        assert probs.shape == (2, 7, 2)

    def test_single_adapter_refused(self) -> None:
        """Gating needs at least two adapters to choose between."""
        with pytest.raises(ValueError, match="n_adapters must be >= 2"):
            Gate(input_dim=8, hidden_proj_dim=4, n_adapters=1)

    def test_nonpositive_dims_refused(self) -> None:
        """A zero input or hidden width is rejected, and the error names the argument."""
        cases = (
            ("input_dim", {"input_dim": 0, "hidden_proj_dim": 4, "n_adapters": 2}),
            ("hidden_proj_dim", {"input_dim": 8, "hidden_proj_dim": 0, "n_adapters": 2}),
        )
        for field, kwargs in cases:
            with pytest.raises(ValueError, match=field):
                Gate(**kwargs)
44
45
class TestGateMetadataJson:
    """``GateMetadata`` serialisation: round trips and ``from_json`` schema validation."""

    @staticmethod
    def _payload(**overrides: object) -> dict[str, object]:
        """Return a baseline ``from_json`` payload with *overrides* merged in.

        Each validation test below breaks exactly one field of this baseline,
        so the expected error can be attributed to that field alone.
        """
        payload: dict[str, object] = {
            "input_dim": 8,
            "hidden_proj_dim": 4,
            "adapter_names": ["a", "b"],
            "mode": "trained",
        }
        payload.update(overrides)
        return payload

    def test_round_trip(self) -> None:
        """Metadata survives ``to_json`` -> ``from_json``, including via real JSON text."""
        import json

        meta = GateMetadata(
            input_dim=512,
            hidden_proj_dim=64,
            adapter_names=("lexer", "runtime"),
            mode="trained",
            entropy_lambda=0.02,
        )
        raw = meta.to_json()
        # In-memory round trip on the payload object itself.
        assert GateMetadata.from_json(raw) == meta
        # Round trip through encoded JSON text. This catches non-serialisable
        # values, and JSON has no tuple type, so ``adapter_names`` comes back
        # as a list: from_json must cope with that and still reproduce ``meta``.
        restored = GateMetadata.from_json(json.loads(json.dumps(raw)))
        assert restored == meta

    def test_missing_required_field(self) -> None:
        """Omitting a required key (here ``adapter_names``) is reported."""
        payload = self._payload()
        del payload["adapter_names"]
        with pytest.raises(GateConfigError, match="missing fields"):
            GateMetadata.from_json(payload)

    def test_bad_mode(self) -> None:
        """An unrecognised ``mode`` value is rejected."""
        payload = self._payload(mode="bogus")
        with pytest.raises(GateConfigError, match="mode"):
            GateMetadata.from_json(payload)

    def test_adapter_names_not_list(self) -> None:
        """A bare string for ``adapter_names`` is rejected rather than split into characters."""
        payload = self._payload(adapter_names="not-a-list")
        with pytest.raises(GateConfigError, match="adapter_names"):
            GateMetadata.from_json(payload)

    def test_non_integer_dims_rejected(self) -> None:
        """Dimension fields must be integers, not numeric-looking strings."""
        payload = self._payload(input_dim="8")
        with pytest.raises(GateConfigError, match="input_dim/hidden_proj_dim"):
            GateMetadata.from_json(payload)

    def test_non_numeric_entropy_rejected(self) -> None:
        """An optional ``entropy_lambda``, when present, must be numeric."""
        payload = self._payload(entropy_lambda="high")
        with pytest.raises(GateConfigError, match="entropy_lambda"):
            GateMetadata.from_json(payload)