sway(viz): matplotlib plots for SIS, adapter ablation, KL histogram (viz extra)
- SHA: 228d4049a2215a71a26a361fe8db2926bb5a429f
- Parents: 2e074c6
- Tree: fe511b1

228d404
228d4049a2215a71a26a361fe8db2926bb5a429f
2e074c6
fe511b1

| Status | File | + | - |
|---|---|---|---|
| M | pyproject.toml | 2 | 0 |
| A | src/dlm_sway/visualize.py | 137 | 0 |
| A | tests/unit/test_visualize.py | 202 | 0 |
pyproject.toml modified @@ -188,6 +188,8 @@ module = [ | ||
| 188 | 188 | "spacy.*", |
| 189 | 189 | "textstat.*", |
| 190 | 190 | "nlpaug.*", |
| 191 | + "matplotlib", | |
| 192 | + "matplotlib.*", | |
| 191 | 193 | "huggingface_hub.*", |
| 192 | 194 | "dlm.*", |
| 193 | 195 | ] |
src/dlm_sway/visualize.py added @@ -0,0 +1,137 @@ | ||
| 1 | +"""Optional matplotlib-based visualizations. | |
| 2 | + | |
| 3 | +Behind the ``viz`` extra. Three functions cover the three plots that | |
| 4 | +make the sway report come alive in a notebook or saved PNG: | |
| 5 | + | |
| 6 | +- :func:`plot_section_sis`: per-section bar chart of effective SIS | |
| 7 | + (the flagship attribution view). | |
| 8 | +- :func:`plot_adapter_ablation`: the λ-scaled divergence curve — the | |
| 9 | + sway signature plot. | |
| 10 | +- :func:`plot_kl_histogram`: distribution of per-prompt KL divergences | |
| 11 | + (the raw data behind A1 DeltaKL). | |
| 12 | + | |
| 13 | +Each function raises :class:`~dlm_sway.core.errors.BackendNotAvailableError` | |
| 14 | +with a pip hint when matplotlib isn't installed. No function writes to | |
| 15 | +disk on your behalf — the caller decides (``fig.savefig(...)``). | |
| 16 | +""" | |
| 17 | + | |
| 18 | +from __future__ import annotations | |
| 19 | + | |
| 20 | +from typing import Any | |
| 21 | + | |
| 22 | +from dlm_sway.core.errors import BackendNotAvailableError | |
| 23 | +from dlm_sway.core.result import SuiteResult | |
| 24 | + | |
| 25 | + | |
| 26 | +def _require_mpl() -> Any: | |
| 27 | + try: | |
| 28 | + import matplotlib.pyplot as plt | |
| 29 | + | |
| 30 | + return plt | |
| 31 | + except ImportError as exc: | |
| 32 | + raise BackendNotAvailableError( | |
| 33 | + "visualize", | |
| 34 | + extra="viz", | |
| 35 | + hint="sway's visualization module needs matplotlib.", | |
| 36 | + ) from exc | |
| 37 | + | |
| 38 | + | |
| 39 | +def plot_section_sis(suite: SuiteResult) -> Any: | |
| 40 | + """Render a per-section ``effective_sis`` bar chart. | |
| 41 | + | |
| 42 | + Returns the matplotlib ``Figure``; the caller handles display / save. | |
| 43 | + """ | |
| 44 | + plt = _require_mpl() | |
| 45 | + | |
| 46 | + probe = _find_probe(suite, "section_internalization") | |
| 47 | + if probe is None or not probe.evidence.get("per_section"): | |
| 48 | + raise ValueError("suite has no section_internalization evidence to plot") | |
| 49 | + | |
| 50 | + rows: list[dict[str, Any]] = list(probe.evidence["per_section"]) | |
| 51 | + labels = [f"{row['tag'] or row['section_id'][:8]}\n({row['kind']})" for row in rows] | |
| 52 | + values = [float(row["effective_sis"]) for row in rows] | |
| 53 | + colors = ["#2ca02c" if row["passed"] else "#d62728" for row in rows] | |
| 54 | + | |
| 55 | + fig, ax = plt.subplots(figsize=(max(6.0, 0.7 * len(rows)), 4.0)) | |
| 56 | + ax.bar(range(len(rows)), values, color=colors) | |
| 57 | + ax.axhline( | |
| 58 | + float(probe.evidence.get("per_section_threshold", 0.0)), | |
| 59 | + color="gray", | |
| 60 | + linestyle="--", | |
| 61 | + linewidth=1, | |
| 62 | + label="threshold", | |
| 63 | + ) | |
| 64 | + ax.set_xticks(range(len(rows))) | |
| 65 | + ax.set_xticklabels(labels, rotation=30, ha="right") | |
| 66 | + ax.set_ylabel("effective SIS") | |
| 67 | + ax.set_title("Section Internalization Score") | |
| 68 | + ax.legend(loc="best") | |
| 69 | + fig.tight_layout() | |
| 70 | + return fig | |
| 71 | + | |
| 72 | + | |
| 73 | +def plot_adapter_ablation(suite: SuiteResult) -> Any: | |
| 74 | + """Render the signature λ-scaled divergence curve.""" | |
| 75 | + plt = _require_mpl() | |
| 76 | + | |
| 77 | + probe = _find_probe(suite, "adapter_ablation") | |
| 78 | + if probe is None or not probe.evidence.get("lambdas"): | |
| 79 | + raise ValueError("suite has no adapter_ablation evidence to plot") | |
| 80 | + | |
| 81 | + lambdas = list(probe.evidence["lambdas"]) | |
| 82 | + divs = list(probe.evidence["mean_divergence_per_lambda"]) | |
| 83 | + | |
| 84 | + fig, ax = plt.subplots(figsize=(7.0, 4.0)) | |
| 85 | + ax.plot(lambdas, divs, marker="o", linewidth=2, color="#1f77b4") | |
| 86 | + ax.axvline(1.0, color="gray", linestyle=":", linewidth=1, label="λ=1 (trained)") | |
| 87 | + sat = probe.evidence.get("saturation_lambda") | |
| 88 | + if sat is not None: | |
| 89 | + ax.axvline( | |
| 90 | + float(sat), | |
| 91 | + color="#2ca02c", | |
| 92 | + linestyle="--", | |
| 93 | + linewidth=1, | |
| 94 | + label=f"sat λ={float(sat):.2f}", | |
| 95 | + ) | |
| 96 | + ax.set_xlabel("λ (adapter scale)") | |
| 97 | + ax.set_ylabel("mean JS divergence vs λ=0") | |
| 98 | + ax.set_title( | |
| 99 | + f"Adapter Ablation (R²={float(probe.evidence.get('linearity', 0.0)):.2f}, " | |
| 100 | + f"overshoot={float(probe.evidence.get('overshoot', 0.0)):.2f})" | |
| 101 | + ) | |
| 102 | + ax.legend(loc="best") | |
| 103 | + fig.tight_layout() | |
| 104 | + return fig | |
| 105 | + | |
| 106 | + | |
| 107 | +def plot_kl_histogram(suite: SuiteResult) -> Any: | |
| 108 | + """Render the per-prompt KL distribution from a DeltaKL probe.""" | |
| 109 | + plt = _require_mpl() | |
| 110 | + | |
| 111 | + probe = _find_probe(suite, "delta_kl") | |
| 112 | + if probe is None or not probe.evidence.get("per_prompt"): | |
| 113 | + raise ValueError("suite has no delta_kl evidence to plot") | |
| 114 | + | |
| 115 | + values = list(probe.evidence["per_prompt"]) | |
| 116 | + fig, ax = plt.subplots(figsize=(7.0, 4.0)) | |
| 117 | + ax.hist(values, bins=max(5, min(20, len(values) // 2)), color="#ff7f0e", edgecolor="white") | |
| 118 | + ax.axvline( | |
| 119 | + float(probe.raw or 0.0), | |
| 120 | + color="black", | |
| 121 | + linestyle="--", | |
| 122 | + linewidth=1, | |
| 123 | + label=f"mean={float(probe.raw or 0.0):.3f}", | |
| 124 | + ) | |
| 125 | + ax.set_xlabel(probe.evidence.get("divergence_kind", "divergence")) | |
| 126 | + ax.set_ylabel("count") | |
| 127 | + ax.set_title("DeltaKL — per-prompt distribution") | |
| 128 | + ax.legend(loc="best") | |
| 129 | + fig.tight_layout() | |
| 130 | + return fig | |
| 131 | + | |
| 132 | + | |
| 133 | +def _find_probe(suite: SuiteResult, kind: str) -> Any: | |
| 134 | + for p in suite.probes: | |
| 135 | + if p.kind == kind: | |
| 136 | + return p | |
| 137 | + return None | |
tests/unit/test_visualize.py added @@ -0,0 +1,202 @@ | ||
| 1 | +"""Tests for :mod:`dlm_sway.visualize`. | |
| 2 | + | |
| 3 | +Exercises the error path (matplotlib missing) and the happy path when | |
| 4 | +the module is present by stubbing ``matplotlib.pyplot`` via sys.modules. | |
| 5 | +""" | |
| 6 | + | |
| 7 | +from __future__ import annotations | |
| 8 | + | |
| 9 | +import sys | |
| 10 | +import types | |
| 11 | +from datetime import timedelta | |
| 12 | + | |
| 13 | +import pytest | |
| 14 | + | |
| 15 | +from dlm_sway.core.errors import BackendNotAvailableError | |
| 16 | +from dlm_sway.core.result import ProbeResult, SuiteResult, Verdict, utcnow | |
| 17 | + | |
| 18 | + | |
| 19 | +def _suite_with(*probes: ProbeResult) -> SuiteResult: | |
| 20 | + started = utcnow() | |
| 21 | + return SuiteResult( | |
| 22 | + spec_path="sway.yaml", | |
| 23 | + started_at=started, | |
| 24 | + finished_at=started + timedelta(seconds=1), | |
| 25 | + base_model_id="b", | |
| 26 | + adapter_id="a", | |
| 27 | + sway_version="0.1.0.dev0", | |
| 28 | + probes=probes, | |
| 29 | + ) | |
| 30 | + | |
| 31 | + | |
| 32 | +class _FakeFig: | |
| 33 | + def tight_layout(self) -> None: # pragma: no cover — trivial | |
| 34 | + return None | |
| 35 | + | |
| 36 | + | |
| 37 | +class _FakeAx: | |
| 38 | + def __init__(self) -> None: | |
| 39 | + self.calls: list[str] = [] | |
| 40 | + | |
| 41 | + def bar(self, *a, **k): # type: ignore[no-untyped-def] | |
| 42 | + self.calls.append("bar") | |
| 43 | + | |
| 44 | + def plot(self, *a, **k): # type: ignore[no-untyped-def] | |
| 45 | + self.calls.append("plot") | |
| 46 | + | |
| 47 | + def hist(self, *a, **k): # type: ignore[no-untyped-def] | |
| 48 | + self.calls.append("hist") | |
| 49 | + | |
| 50 | + def axhline(self, *a, **k): # type: ignore[no-untyped-def] | |
| 51 | + return None | |
| 52 | + | |
| 53 | + def axvline(self, *a, **k): # type: ignore[no-untyped-def] | |
| 54 | + return None | |
| 55 | + | |
| 56 | + def set_xticks(self, *a, **k): # type: ignore[no-untyped-def] | |
| 57 | + return None | |
| 58 | + | |
| 59 | + def set_xticklabels(self, *a, **k): # type: ignore[no-untyped-def] | |
| 60 | + return None | |
| 61 | + | |
| 62 | + def set_xlabel(self, *a, **k): # type: ignore[no-untyped-def] | |
| 63 | + return None | |
| 64 | + | |
| 65 | + def set_ylabel(self, *a, **k): # type: ignore[no-untyped-def] | |
| 66 | + return None | |
| 67 | + | |
| 68 | + def set_title(self, *a, **k): # type: ignore[no-untyped-def] | |
| 69 | + return None | |
| 70 | + | |
| 71 | + def legend(self, *a, **k): # type: ignore[no-untyped-def] | |
| 72 | + return None | |
| 73 | + | |
| 74 | + | |
| 75 | +@pytest.fixture | |
| 76 | +def fake_mpl(monkeypatch: pytest.MonkeyPatch) -> _FakeAx: | |
| 77 | + ax = _FakeAx() | |
| 78 | + | |
| 79 | + def _subplots(*a, **k): # type: ignore[no-untyped-def] | |
| 80 | + return _FakeFig(), ax | |
| 81 | + | |
| 82 | + plt = types.ModuleType("matplotlib.pyplot") | |
| 83 | + plt.subplots = _subplots # type: ignore[attr-defined] | |
| 84 | + mpl_pkg = types.ModuleType("matplotlib") | |
| 85 | + monkeypatch.setitem(sys.modules, "matplotlib", mpl_pkg) | |
| 86 | + monkeypatch.setitem(sys.modules, "matplotlib.pyplot", plt) | |
| 87 | + return ax | |
| 88 | + | |
| 89 | + | |
| 90 | +def test_section_sis_plot_uses_per_section_evidence(fake_mpl: _FakeAx) -> None: | |
| 91 | + from dlm_sway.visualize import plot_section_sis | |
| 92 | + | |
| 93 | + suite = _suite_with( | |
| 94 | + ProbeResult( | |
| 95 | + name="sis", | |
| 96 | + kind="section_internalization", | |
| 97 | + verdict=Verdict.PASS, | |
| 98 | + score=0.75, | |
| 99 | + raw=0.1, | |
| 100 | + evidence={ | |
| 101 | + "per_section": [ | |
| 102 | + { | |
| 103 | + "section_id": "a", | |
| 104 | + "kind": "prose", | |
| 105 | + "tag": None, | |
| 106 | + "base_nll": 3.0, | |
| 107 | + "ft_nll": 2.5, | |
| 108 | + "own_lift": 0.17, | |
| 109 | + "leak_lift": 0.02, | |
| 110 | + "effective_sis": 0.15, | |
| 111 | + "passed": True, | |
| 112 | + }, | |
| 113 | + { | |
| 114 | + "section_id": "b", | |
| 115 | + "kind": "instruction", | |
| 116 | + "tag": "intro", | |
| 117 | + "base_nll": 4.0, | |
| 118 | + "ft_nll": 3.9, | |
| 119 | + "own_lift": 0.025, | |
| 120 | + "leak_lift": 0.03, | |
| 121 | + "effective_sis": -0.005, | |
| 122 | + "passed": False, | |
| 123 | + }, | |
| 124 | + ], | |
| 125 | + "per_section_threshold": 0.05, | |
| 126 | + }, | |
| 127 | + ) | |
| 128 | + ) | |
| 129 | + plot_section_sis(suite) | |
| 130 | + assert "bar" in fake_mpl.calls | |
| 131 | + | |
| 132 | + | |
| 133 | +def test_adapter_ablation_plot(fake_mpl: _FakeAx) -> None: | |
| 134 | + from dlm_sway.visualize import plot_adapter_ablation | |
| 135 | + | |
| 136 | + suite = _suite_with( | |
| 137 | + ProbeResult( | |
| 138 | + name="abl", | |
| 139 | + kind="adapter_ablation", | |
| 140 | + verdict=Verdict.PASS, | |
| 141 | + score=0.8, | |
| 142 | + raw=0.9, | |
| 143 | + evidence={ | |
| 144 | + "lambdas": [0.0, 0.5, 1.0, 1.25], | |
| 145 | + "mean_divergence_per_lambda": [0.0, 0.5, 1.0, 1.1], | |
| 146 | + "linearity": 0.91, | |
| 147 | + "saturation_lambda": 0.75, | |
| 148 | + "overshoot": 1.1, | |
| 149 | + }, | |
| 150 | + ) | |
| 151 | + ) | |
| 152 | + plot_adapter_ablation(suite) | |
| 153 | + assert "plot" in fake_mpl.calls | |
| 154 | + | |
| 155 | + | |
| 156 | +def test_kl_histogram_plot(fake_mpl: _FakeAx) -> None: | |
| 157 | + from dlm_sway.visualize import plot_kl_histogram | |
| 158 | + | |
| 159 | + suite = _suite_with( | |
| 160 | + ProbeResult( | |
| 161 | + name="dk", | |
| 162 | + kind="delta_kl", | |
| 163 | + verdict=Verdict.PASS, | |
| 164 | + score=0.7, | |
| 165 | + raw=0.1, | |
| 166 | + evidence={"per_prompt": [0.05, 0.1, 0.12, 0.09, 0.15], "divergence_kind": "js"}, | |
| 167 | + ) | |
| 168 | + ) | |
| 169 | + plot_kl_histogram(suite) | |
| 170 | + assert "hist" in fake_mpl.calls | |
| 171 | + | |
| 172 | + | |
| 173 | +def test_raises_when_matplotlib_missing(monkeypatch: pytest.MonkeyPatch) -> None: | |
| 174 | + # Purge matplotlib modules and block imports. | |
| 175 | + for mod in list(sys.modules): | |
| 176 | + if mod == "matplotlib" or mod.startswith("matplotlib."): | |
| 177 | + monkeypatch.delitem(sys.modules, mod, raising=False) | |
| 178 | + | |
| 179 | + import builtins | |
| 180 | + | |
| 181 | + real_import = builtins.__import__ | |
| 182 | + | |
| 183 | + def fake_import(name: str, *a, **k): # type: ignore[no-untyped-def] | |
| 184 | + if name == "matplotlib" or name.startswith("matplotlib."): | |
| 185 | + raise ImportError("matplotlib missing in this venv") | |
| 186 | + return real_import(name, *a, **k) | |
| 187 | + | |
| 188 | + monkeypatch.setattr(builtins, "__import__", fake_import) | |
| 189 | + | |
| 190 | + from dlm_sway.visualize import plot_section_sis | |
| 191 | + | |
| 192 | + suite = _suite_with() | |
| 193 | + with pytest.raises(BackendNotAvailableError): | |
| 194 | + plot_section_sis(suite) | |
| 195 | + | |
| 196 | + | |
| 197 | +def test_raises_when_no_matching_probe(fake_mpl: _FakeAx) -> None: | |
| 198 | + from dlm_sway.visualize import plot_section_sis | |
| 199 | + | |
| 200 | + suite = _suite_with() # empty — no section_internalization probe | |
| 201 | + with pytest.raises(ValueError, match="section_internalization"): | |
| 202 | + plot_section_sis(suite) | |