Python · 6044 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.visualize`.
2
3 Exercises the error path (matplotlib missing) and the happy path when
4 the module is present by stubbing ``matplotlib.pyplot`` via sys.modules.
5 """
6
7 from __future__ import annotations
8
9 import sys
10 import types
11 from datetime import timedelta
12
13 import pytest
14
15 from dlm_sway.core.errors import BackendNotAvailableError
16 from dlm_sway.core.result import ProbeResult, SuiteResult, Verdict, utcnow
17
18
def _suite_with(*probes: ProbeResult) -> SuiteResult:
    """Wrap *probes* in a ``SuiteResult`` carrying fixed placeholder metadata.

    The suite starts at the current ``utcnow()`` and finishes exactly one
    second later; the received probes are passed through unchanged as the
    positional-argument tuple.
    """
    begin = utcnow()
    end = begin + timedelta(seconds=1)
    return SuiteResult(
        spec_path="sway.yaml",
        started_at=begin,
        finished_at=end,
        base_model_id="b",
        adapter_id="a",
        sway_version="0.1.0.dev0",
        probes=probes,
    )
30
31
32 class _FakeFig:
33 def tight_layout(self) -> None: # pragma: no cover — trivial
34 return None
35
36
37 class _FakeAx:
38 def __init__(self) -> None:
39 self.calls: list[str] = []
40
41 def bar(self, *a, **k): # type: ignore[no-untyped-def]
42 self.calls.append("bar")
43
44 def plot(self, *a, **k): # type: ignore[no-untyped-def]
45 self.calls.append("plot")
46
47 def hist(self, *a, **k): # type: ignore[no-untyped-def]
48 self.calls.append("hist")
49
50 def axhline(self, *a, **k): # type: ignore[no-untyped-def]
51 return None
52
53 def axvline(self, *a, **k): # type: ignore[no-untyped-def]
54 return None
55
56 def set_xticks(self, *a, **k): # type: ignore[no-untyped-def]
57 return None
58
59 def set_xticklabels(self, *a, **k): # type: ignore[no-untyped-def]
60 return None
61
62 def set_xlabel(self, *a, **k): # type: ignore[no-untyped-def]
63 return None
64
65 def set_ylabel(self, *a, **k): # type: ignore[no-untyped-def]
66 return None
67
68 def set_title(self, *a, **k): # type: ignore[no-untyped-def]
69 return None
70
71 def legend(self, *a, **k): # type: ignore[no-untyped-def]
72 return None
73
74
@pytest.fixture
def fake_mpl(monkeypatch: pytest.MonkeyPatch) -> _FakeAx:
    """Install stub ``matplotlib`` and ``matplotlib.pyplot`` modules.

    Both stubs are registered in ``sys.modules`` through ``monkeypatch``, so
    the entries are restored when the test finishes. The stub
    ``pyplot.subplots`` accepts any arguments and returns a new ``_FakeFig``
    together with one shared ``_FakeAx``.

    Returns:
        The shared ``_FakeAx``, so tests can inspect its recorded ``calls``.
    """
    ax = _FakeAx()

    def _subplots(*a, **k):  # type: ignore[no-untyped-def]
        return _FakeFig(), ax

    plt = types.ModuleType("matplotlib.pyplot")
    plt.subplots = _subplots  # type: ignore[attr-defined]
    mpl_pkg = types.ModuleType("matplotlib")
    # A real import binds the loaded submodule as an attribute of its parent
    # package. Mirror that here so ``import matplotlib.pyplot`` followed by
    # ``matplotlib.pyplot.subplots(...)`` works against the stub, not only
    # ``import matplotlib.pyplot as plt``.
    mpl_pkg.pyplot = plt  # type: ignore[attr-defined]
    monkeypatch.setitem(sys.modules, "matplotlib", mpl_pkg)
    monkeypatch.setitem(sys.modules, "matplotlib.pyplot", plt)
    return ax
88
89
def test_section_sis_plot_uses_per_section_evidence(fake_mpl: _FakeAx) -> None:
    """``plot_section_sis`` issues a ``bar`` call for per-section evidence.

    The suite holds one ``section_internalization`` probe with two sections,
    one marked passed and one marked failed, plus a per-section threshold.
    """
    from dlm_sway.visualize import plot_section_sis

    passing_section = {
        "section_id": "a",
        "kind": "prose",
        "tag": None,
        "base_nll": 3.0,
        "ft_nll": 2.5,
        "own_lift": 0.17,
        "leak_lift": 0.02,
        "effective_sis": 0.15,
        "passed": True,
    }
    failing_section = {
        "section_id": "b",
        "kind": "instruction",
        "tag": "intro",
        "base_nll": 4.0,
        "ft_nll": 3.9,
        "own_lift": 0.025,
        "leak_lift": 0.03,
        "effective_sis": -0.005,
        "passed": False,
    }
    sis_probe = ProbeResult(
        name="sis",
        kind="section_internalization",
        verdict=Verdict.PASS,
        score=0.75,
        raw=0.1,
        evidence={
            "per_section": [passing_section, failing_section],
            "per_section_threshold": 0.05,
        },
    )
    suite = _suite_with(sis_probe)

    plot_section_sis(suite)

    assert "bar" in fake_mpl.calls
131
132
def test_adapter_ablation_plot(fake_mpl: _FakeAx) -> None:
    """``plot_adapter_ablation`` issues a ``plot`` call on the stub axes.

    The suite holds one ``adapter_ablation`` probe with four lambda values
    and their matching mean divergences.
    """
    from dlm_sway.visualize import plot_adapter_ablation

    ablation_evidence = {
        "lambdas": [0.0, 0.5, 1.0, 1.25],
        "mean_divergence_per_lambda": [0.0, 0.5, 1.0, 1.1],
        "linearity": 0.91,
        "saturation_lambda": 0.75,
        "overshoot": 1.1,
    }
    ablation_probe = ProbeResult(
        name="abl",
        kind="adapter_ablation",
        verdict=Verdict.PASS,
        score=0.8,
        raw=0.9,
        evidence=ablation_evidence,
    )
    suite = _suite_with(ablation_probe)

    plot_adapter_ablation(suite)

    assert "plot" in fake_mpl.calls
154
155
def test_kl_histogram_plot(fake_mpl: _FakeAx) -> None:
    """``plot_kl_histogram`` issues a ``hist`` call on the stub axes.

    The suite holds one ``delta_kl`` probe with five per-prompt values.
    """
    from dlm_sway.visualize import plot_kl_histogram

    kl_evidence = {
        "per_prompt": [0.05, 0.1, 0.12, 0.09, 0.15],
        "divergence_kind": "js",
    }
    kl_probe = ProbeResult(
        name="dk",
        kind="delta_kl",
        verdict=Verdict.PASS,
        score=0.7,
        raw=0.1,
        evidence=kl_evidence,
    )
    suite = _suite_with(kl_probe)

    plot_kl_histogram(suite)

    assert "hist" in fake_mpl.calls
171
172
def test_raises_when_matplotlib_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """``plot_section_sis`` raises ``BackendNotAvailableError`` without matplotlib.

    Cached ``matplotlib`` modules are dropped from ``sys.modules`` and
    ``builtins.__import__`` is wrapped so that any further import of
    ``matplotlib`` or one of its submodules raises ``ImportError``.
    """
    import builtins

    def _is_matplotlib(module_name: str) -> bool:
        # True for "matplotlib" itself and for any dotted name beneath it.
        return module_name.partition(".")[0] == "matplotlib"

    # Snapshot the matching keys first: entries are removed inside the loop.
    for cached in [key for key in sys.modules if _is_matplotlib(key)]:
        monkeypatch.delitem(sys.modules, cached, raising=False)

    original_import = builtins.__import__

    def _blocking_import(name: str, *args, **kwargs):  # type: ignore[no-untyped-def]
        if not _is_matplotlib(name):
            return original_import(name, *args, **kwargs)
        raise ImportError("matplotlib missing in this venv")

    monkeypatch.setattr(builtins, "__import__", _blocking_import)

    # Imported only after the import hook is in place, as in the original.
    from dlm_sway.visualize import plot_section_sis

    empty_suite = _suite_with()
    with pytest.raises(BackendNotAvailableError):
        plot_section_sis(empty_suite)
195
196
def test_raises_when_no_matching_probe(fake_mpl: _FakeAx) -> None:
    """``plot_section_sis`` raises ``ValueError`` for a suite with no probes.

    The error message must mention ``section_internalization``.
    """
    from dlm_sway.visualize import plot_section_sis

    # No probes at all, hence no section_internalization probe to plot.
    probe_free_suite = _suite_with()
    with pytest.raises(ValueError, match="section_internalization"):
        plot_section_sis(probe_free_suite)
202 plot_section_sis(suite)