| 1 | """Tests for :mod:`dlm_sway.visualize`. |
| 2 | |
| 3 | Exercises the error path (matplotlib missing) and the happy path when |
| 4 | the module is present by stubbing ``matplotlib.pyplot`` via sys.modules. |
| 5 | """ |
| 6 | |
| 7 | from __future__ import annotations |
| 8 | |
| 9 | import sys |
| 10 | import types |
| 11 | from datetime import timedelta |
| 12 | |
| 13 | import pytest |
| 14 | |
| 15 | from dlm_sway.core.errors import BackendNotAvailableError |
| 16 | from dlm_sway.core.result import ProbeResult, SuiteResult, Verdict, utcnow |
| 17 | |
| 18 | |
def _suite_with(*probes: ProbeResult) -> SuiteResult:
    """Wrap ``probes`` in a minimal ``SuiteResult`` for the plotting tests.

    The run starts at ``utcnow()`` and finishes exactly one second later;
    every identifying field is a fixed placeholder value.
    """
    began = utcnow()
    ended = began + timedelta(seconds=1)
    return SuiteResult(
        spec_path="sway.yaml",
        started_at=began,
        finished_at=ended,
        base_model_id="b",
        adapter_id="a",
        sway_version="0.1.0.dev0",
        probes=probes,
    )
| 30 | |
| 31 | |
| 32 | class _FakeFig: |
| 33 | def tight_layout(self) -> None: # pragma: no cover — trivial |
| 34 | return None |
| 35 | |
| 36 | |
| 37 | class _FakeAx: |
| 38 | def __init__(self) -> None: |
| 39 | self.calls: list[str] = [] |
| 40 | |
| 41 | def bar(self, *a, **k): # type: ignore[no-untyped-def] |
| 42 | self.calls.append("bar") |
| 43 | |
| 44 | def plot(self, *a, **k): # type: ignore[no-untyped-def] |
| 45 | self.calls.append("plot") |
| 46 | |
| 47 | def hist(self, *a, **k): # type: ignore[no-untyped-def] |
| 48 | self.calls.append("hist") |
| 49 | |
| 50 | def axhline(self, *a, **k): # type: ignore[no-untyped-def] |
| 51 | return None |
| 52 | |
| 53 | def axvline(self, *a, **k): # type: ignore[no-untyped-def] |
| 54 | return None |
| 55 | |
| 56 | def set_xticks(self, *a, **k): # type: ignore[no-untyped-def] |
| 57 | return None |
| 58 | |
| 59 | def set_xticklabels(self, *a, **k): # type: ignore[no-untyped-def] |
| 60 | return None |
| 61 | |
| 62 | def set_xlabel(self, *a, **k): # type: ignore[no-untyped-def] |
| 63 | return None |
| 64 | |
| 65 | def set_ylabel(self, *a, **k): # type: ignore[no-untyped-def] |
| 66 | return None |
| 67 | |
| 68 | def set_title(self, *a, **k): # type: ignore[no-untyped-def] |
| 69 | return None |
| 70 | |
| 71 | def legend(self, *a, **k): # type: ignore[no-untyped-def] |
| 72 | return None |
| 73 | |
| 74 | |
@pytest.fixture
def fake_mpl(monkeypatch: pytest.MonkeyPatch) -> _FakeAx:
    """Install stub ``matplotlib`` and ``matplotlib.pyplot`` modules.

    Both stubs are registered in ``sys.modules`` through ``monkeypatch``,
    which restores the previous entries when the test finishes. The stub
    ``pyplot`` provides only ``subplots``; it ignores its arguments and
    returns a fresh ``_FakeFig`` paired with one shared ``_FakeAx``.

    Returns:
        The shared ``_FakeAx``, so a test can inspect its ``calls`` list.
    """
    ax = _FakeAx()

    def _subplots(*a, **k):  # type: ignore[no-untyped-def]
        return _FakeFig(), ax

    plt = types.ModuleType("matplotlib.pyplot")
    plt.subplots = _subplots  # type: ignore[attr-defined]
    mpl_pkg = types.ModuleType("matplotlib")
    # A real import binds a submodule as an attribute of its parent package.
    # Mirror that so ``import matplotlib.pyplot`` followed by attribute access
    # (``matplotlib.pyplot.subplots``) works, not only ``from``/``as`` imports.
    mpl_pkg.pyplot = plt  # type: ignore[attr-defined]
    monkeypatch.setitem(sys.modules, "matplotlib", mpl_pkg)
    monkeypatch.setitem(sys.modules, "matplotlib.pyplot", plt)
    return ax
| 88 | |
| 89 | |
def test_section_sis_plot_uses_per_section_evidence(fake_mpl: _FakeAx) -> None:
    """Plotting a ``section_internalization`` probe must call ``bar`` on the axes."""
    # Imported here, after the fixture has installed the stub modules.
    from dlm_sway.visualize import plot_section_sis

    # One section above the threshold and one below it.
    passing_section = {
        "section_id": "a",
        "kind": "prose",
        "tag": None,
        "base_nll": 3.0,
        "ft_nll": 2.5,
        "own_lift": 0.17,
        "leak_lift": 0.02,
        "effective_sis": 0.15,
        "passed": True,
    }
    failing_section = {
        "section_id": "b",
        "kind": "instruction",
        "tag": "intro",
        "base_nll": 4.0,
        "ft_nll": 3.9,
        "own_lift": 0.025,
        "leak_lift": 0.03,
        "effective_sis": -0.005,
        "passed": False,
    }
    sis_probe = ProbeResult(
        name="sis",
        kind="section_internalization",
        verdict=Verdict.PASS,
        score=0.75,
        raw=0.1,
        evidence={
            "per_section": [passing_section, failing_section],
            "per_section_threshold": 0.05,
        },
    )

    plot_section_sis(_suite_with(sis_probe))

    assert "bar" in fake_mpl.calls
| 131 | |
| 132 | |
def test_adapter_ablation_plot(fake_mpl: _FakeAx) -> None:
    """Plotting an ``adapter_ablation`` probe must call ``plot`` on the axes."""
    # Imported here, after the fixture has installed the stub modules.
    from dlm_sway.visualize import plot_adapter_ablation

    ablation_evidence = {
        "lambdas": [0.0, 0.5, 1.0, 1.25],
        "mean_divergence_per_lambda": [0.0, 0.5, 1.0, 1.1],
        "linearity": 0.91,
        "saturation_lambda": 0.75,
        "overshoot": 1.1,
    }
    ablation_probe = ProbeResult(
        name="abl",
        kind="adapter_ablation",
        verdict=Verdict.PASS,
        score=0.8,
        raw=0.9,
        evidence=ablation_evidence,
    )

    plot_adapter_ablation(_suite_with(ablation_probe))

    assert "plot" in fake_mpl.calls
| 154 | |
| 155 | |
def test_kl_histogram_plot(fake_mpl: _FakeAx) -> None:
    """Plotting a ``delta_kl`` probe must call ``hist`` on the axes."""
    # Imported here, after the fixture has installed the stub modules.
    from dlm_sway.visualize import plot_kl_histogram

    kl_evidence = {
        "per_prompt": [0.05, 0.1, 0.12, 0.09, 0.15],
        "divergence_kind": "js",
    }
    kl_probe = ProbeResult(
        name="dk",
        kind="delta_kl",
        verdict=Verdict.PASS,
        score=0.7,
        raw=0.1,
        evidence=kl_evidence,
    )

    plot_kl_histogram(_suite_with(kl_probe))

    assert "hist" in fake_mpl.calls
| 171 | |
| 172 | |
def test_raises_when_matplotlib_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """With matplotlib unimportable, plotting raises ``BackendNotAvailableError``."""
    import builtins

    def _is_matplotlib(module_name: str) -> bool:
        # True for "matplotlib" itself and for any "matplotlib.<sub>" module.
        return module_name.partition(".")[0] == "matplotlib"

    # Evict cached matplotlib modules so an import cannot be served from
    # ``sys.modules``.
    for cached in [key for key in sys.modules if _is_matplotlib(key)]:
        monkeypatch.delitem(sys.modules, cached, raising=False)

    original_import = builtins.__import__

    def _blocking_import(name: str, *args, **kwargs):  # type: ignore[no-untyped-def]
        # Refuse matplotlib; delegate every other import to the real hook.
        if _is_matplotlib(name):
            raise ImportError("matplotlib missing in this venv")
        return original_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", _blocking_import)

    from dlm_sway.visualize import plot_section_sis

    empty_suite = _suite_with()
    with pytest.raises(BackendNotAvailableError):
        plot_section_sis(empty_suite)
| 195 | |
| 196 | |
def test_raises_when_no_matching_probe(fake_mpl: _FakeAx) -> None:
    """A suite without a ``section_internalization`` probe raises ``ValueError``."""
    # Imported here, after the fixture has installed the stub modules.
    from dlm_sway.visualize import plot_section_sis

    # No probes at all, so there is nothing for the plot to draw from.
    probeless_suite = _suite_with()

    with pytest.raises(ValueError, match="section_internalization"):
        plot_section_sis(probeless_suite)