| 1 |
"""Gate nn.Module — shape + parameter count + metadata round-trip.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
import pytest |
| 6 |
|
| 7 |
from dlm.train.gate import Gate, GateConfigError, GateMetadata |
| 8 |
|
| 9 |
|
| 10 |
class TestGateConstruction:
    """Forward-pass shape, parameter-count, and constructor-validation checks for ``Gate``."""

    def test_forward_shape(self) -> None:
        """A 2D batch maps to one row of adapter weights per input row."""
        import torch

        n_rows, n_adapters = 5, 3
        module = Gate(input_dim=128, hidden_proj_dim=32, n_adapters=n_adapters)
        weights = module(torch.randn(n_rows, 128))

        assert weights.shape == (n_rows, n_adapters)
        # Softmax rows sum to ~1.
        row_totals = weights.sum(dim=-1)
        assert torch.allclose(row_totals, torch.ones(n_rows), atol=1e-5)

    def test_parameter_count(self) -> None:
        """``num_parameters()`` equals the closed-form count for the chosen dims."""
        in_dim, hidden_dim, n_out = 2048, 64, 4
        module = Gate(input_dim=in_dim, hidden_proj_dim=hidden_dim, n_adapters=n_out)

        # Weight + bias terms of an in->hidden map followed by a hidden->out map
        # (presumably the gate's two linear layers; confirm against Gate itself).
        first_stage = in_dim * hidden_dim + hidden_dim
        second_stage = hidden_dim * n_out + n_out
        assert module.num_parameters() == first_stage + second_stage

    def test_batch_dim_preserved(self) -> None:
        """Leading dimensions pass through; only the last one becomes ``n_adapters``."""
        import torch

        module = Gate(input_dim=16, hidden_proj_dim=8, n_adapters=2)
        # 3D input (batch, time, features).
        sequence_batch = torch.randn(2, 7, 16)
        assert module(sequence_batch).shape == (2, 7, 2)

    def test_single_adapter_refused(self) -> None:
        """Constructing a gate over a single adapter is expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="n_adapters must be >= 2"):
            Gate(input_dim=8, hidden_proj_dim=4, n_adapters=1)

    def test_nonpositive_dims_refused(self) -> None:
        """A zero ``input_dim`` or ``hidden_proj_dim`` is expected to raise ``ValueError``."""
        # (expected message pattern, offending dimensions), checked in this order.
        cases = (
            ("input_dim", {"input_dim": 0, "hidden_proj_dim": 4}),
            ("hidden_proj_dim", {"input_dim": 8, "hidden_proj_dim": 0}),
        )
        for pattern, dims in cases:
            with pytest.raises(ValueError, match=pattern):
                Gate(n_adapters=2, **dims)
| 44 |
|
| 45 |
|
| 46 |
class TestGateMetadataJson:
    """``GateMetadata`` JSON round-trip and ``from_json`` rejection checks."""

    @staticmethod
    def _payload(**overrides: object) -> dict[str, object]:
        """Return a fresh, otherwise-valid metadata dict with ``overrides`` applied."""
        payload: dict[str, object] = {
            "input_dim": 8,
            "hidden_proj_dim": 4,
            "adapter_names": ["a", "b"],
            "mode": "trained",
        }
        payload.update(overrides)
        return payload

    def test_round_trip(self) -> None:
        """``from_json(to_json(meta))`` compares equal to the original metadata."""
        original = GateMetadata(
            input_dim=512,
            hidden_proj_dim=64,
            adapter_names=("lexer", "runtime"),
            mode="trained",
            entropy_lambda=0.02,
        )
        serialized = original.to_json()
        assert GateMetadata.from_json(serialized) == original

    def test_missing_required_field(self) -> None:
        """A payload lacking ``adapter_names`` is expected to raise ``GateConfigError``."""
        incomplete = {"input_dim": 8, "hidden_proj_dim": 4, "mode": "trained"}
        with pytest.raises(GateConfigError, match="missing fields"):
            GateMetadata.from_json(incomplete)

    def test_bad_mode(self) -> None:
        """An unrecognized ``mode`` value is expected to raise ``GateConfigError``."""
        bogus_mode = self._payload(mode="bogus")
        with pytest.raises(GateConfigError, match="mode"):
            GateMetadata.from_json(bogus_mode)

    def test_adapter_names_not_list(self) -> None:
        """A plain string for ``adapter_names`` is expected to raise ``GateConfigError``."""
        string_names = self._payload(adapter_names="not-a-list")
        with pytest.raises(GateConfigError, match="adapter_names"):
            GateMetadata.from_json(string_names)

    def test_non_integer_dims_rejected(self) -> None:
        """A string-typed ``input_dim`` is expected to raise ``GateConfigError``."""
        string_dim = self._payload(input_dim="8")
        with pytest.raises(GateConfigError, match="input_dim/hidden_proj_dim"):
            GateMetadata.from_json(string_dim)

    def test_non_numeric_entropy_rejected(self) -> None:
        """A non-numeric ``entropy_lambda`` is expected to raise ``GateConfigError``."""
        string_entropy = self._payload(entropy_lambda="high")
        with pytest.raises(GateConfigError, match="entropy_lambda"):
            GateMetadata.from_json(string_entropy)