tenseleyflow/sway / 228d404

Browse files

sway(viz): matplotlib plots for SIS, adapter ablation, KL histogram (viz extra)

Authored by espadonne
SHA
228d4049a2215a71a26a361fe8db2926bb5a429f
Parents
2e074c6
Tree
fe511b1

3 changed files

Status | File | + | -
M pyproject.toml 2 0
A src/dlm_sway/visualize.py 137 0
A tests/unit/test_visualize.py 202 0
pyproject.toml (modified)
@@ -188,6 +188,8 @@ module = [
188188
     "spacy.*",
189189
     "textstat.*",
190190
     "nlpaug.*",
191
+    "matplotlib",
192
+    "matplotlib.*",
191193
     "huggingface_hub.*",
192194
     "dlm.*",
193195
 ]
src/dlm_sway/visualize.py (added)
@@ -0,0 +1,137 @@
1
+"""Optional matplotlib-based visualizations.
2
+
3
+Behind the ``viz`` extra. Three functions cover the three plots that
4
+make the sway report come alive in a notebook or saved PNG:
5
+
6
+- :func:`plot_section_sis`: per-section bar chart of effective SIS
7
+  (the flagship attribution view).
8
+- :func:`plot_adapter_ablation`: the λ-scaled divergence curve — the
9
+  sway signature plot.
10
+- :func:`plot_kl_histogram`: distribution of per-prompt KL divergences
11
+  (the raw data behind A1 DeltaKL).
12
+
13
+Each function raises :class:`~dlm_sway.core.errors.BackendNotAvailableError`
14
+with a pip hint when matplotlib isn't installed. No function writes to
15
+disk on your behalf — the caller decides (``fig.savefig(...)``).
16
+"""
17
+
18
+from __future__ import annotations
19
+
20
+from typing import Any
21
+
22
+from dlm_sway.core.errors import BackendNotAvailableError
23
+from dlm_sway.core.result import SuiteResult
24
+
25
+
26
def _require_mpl() -> Any:
    """Return the ``matplotlib.pyplot`` module, importing it on demand.

    Raises:
        BackendNotAvailableError: if importing ``matplotlib.pyplot`` fails.
            The error points at the ``viz`` extra and chains the original
            ``ImportError``.
    """
    try:
        import matplotlib.pyplot as pyplot
    except ImportError as import_err:
        raise BackendNotAvailableError(
            "visualize",
            extra="viz",
            hint="sway's visualization module needs matplotlib.",
        ) from import_err
    return pyplot
37
+
38
+
39
def plot_section_sis(suite: SuiteResult) -> Any:
    """Draw one bar per section showing its ``effective_sis`` value.

    Bars are green for rows whose ``passed`` entry is truthy and red
    otherwise. A dashed gray line marks ``per_section_threshold`` (0.0 when
    the evidence carries no such key).

    Returns:
        The matplotlib ``Figure``; nothing is displayed or saved here.

    Raises:
        BackendNotAvailableError: from ``_require_mpl`` when matplotlib
            cannot be imported.
        ValueError: if the suite has no ``section_internalization`` probe or
            that probe's ``per_section`` evidence is empty.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "section_internalization")
    if not (probe is not None and probe.evidence.get("per_section")):
        raise ValueError("suite has no section_internalization evidence to plot")

    section_rows: list[dict[str, Any]] = list(probe.evidence["per_section"])
    count = len(section_rows)

    tick_labels = []
    for entry in section_rows:
        # Prefer the tag; fall back to the first 8 characters of the id.
        title = entry["tag"] or entry["section_id"][:8]
        tick_labels.append(f"{title}\n({entry['kind']})")
    heights = [float(entry["effective_sis"]) for entry in section_rows]
    bar_colors = ["#2ca02c" if entry["passed"] else "#d62728" for entry in section_rows]

    # Widen the figure with the number of sections, but never below 6 inches.
    width = max(6.0, 0.7 * count)
    fig, ax = plt.subplots(figsize=(width, 4.0))
    positions = range(count)
    ax.bar(positions, heights, color=bar_colors)

    threshold = float(probe.evidence.get("per_section_threshold", 0.0))
    ax.axhline(threshold, color="gray", linestyle="--", linewidth=1, label="threshold")

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=30, ha="right")
    ax.set_ylabel("effective SIS")
    ax.set_title("Section Internalization Score")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
71
+
72
+
73
def plot_adapter_ablation(suite: SuiteResult) -> Any:
    """Render the signature λ-scaled divergence curve.

    Plots ``mean_divergence_per_lambda`` against ``lambdas`` from the
    ``adapter_ablation`` probe. A dotted vertical line marks λ=1, and a
    dashed green line marks ``saturation_lambda`` when the evidence provides
    one. The title reports ``linearity`` and ``overshoot``; either one is
    shown as 0.00 when it is missing or ``None``.

    Returns:
        The matplotlib ``Figure``; nothing is displayed or saved here.

    Raises:
        BackendNotAvailableError: from ``_require_mpl`` when matplotlib
            cannot be imported.
        ValueError: if the suite has no ``adapter_ablation`` probe, its
            ``lambdas`` evidence is empty, or ``mean_divergence_per_lambda``
            does not have exactly one value per lambda.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "adapter_ablation")
    if probe is None or not probe.evidence.get("lambdas"):
        raise ValueError("suite has no adapter_ablation evidence to plot")

    lambdas = list(probe.evidence["lambdas"])
    # Validate the y-series before any figure exists: a missing key would
    # otherwise surface as a bare KeyError, and a length mismatch would fail
    # inside matplotlib after ``plt.subplots`` had already created a figure.
    divs = list(probe.evidence.get("mean_divergence_per_lambda") or [])
    if len(divs) != len(lambdas):
        raise ValueError(
            "adapter_ablation evidence is inconsistent: "
            f"{len(lambdas)} lambdas but {len(divs)} mean_divergence_per_lambda values"
        )

    def _metric(key: str) -> float:
        # Treat an explicit ``None`` like a missing key so ``float()`` cannot
        # raise while the title is being formatted.
        value = probe.evidence.get(key)
        return 0.0 if value is None else float(value)

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    ax.plot(lambdas, divs, marker="o", linewidth=2, color="#1f77b4")
    ax.axvline(1.0, color="gray", linestyle=":", linewidth=1, label="λ=1 (trained)")
    sat = probe.evidence.get("saturation_lambda")
    if sat is not None:
        ax.axvline(
            float(sat),
            color="#2ca02c",
            linestyle="--",
            linewidth=1,
            label=f"sat λ={float(sat):.2f}",
        )
    ax.set_xlabel("λ (adapter scale)")
    ax.set_ylabel("mean JS divergence vs λ=0")
    ax.set_title(
        f"Adapter Ablation (R²={_metric('linearity'):.2f}, "
        f"overshoot={_metric('overshoot'):.2f})"
    )
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
105
+
106
+
107
def plot_kl_histogram(suite: SuiteResult) -> Any:
    """Render the per-prompt KL distribution from a DeltaKL probe.

    Draws a histogram of the probe's ``per_prompt`` evidence plus a dashed
    vertical line labelled as the mean. The x-axis label is taken from the
    ``divergence_kind`` evidence, defaulting to ``"divergence"``.

    Returns:
        The matplotlib ``Figure``; nothing is displayed or saved here.

    Raises:
        BackendNotAvailableError: from ``_require_mpl`` when matplotlib
            cannot be imported.
        ValueError: if the suite has no ``delta_kl`` probe or that probe's
            ``per_prompt`` evidence is empty.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "delta_kl")
    if probe is None or not probe.evidence.get("per_prompt"):
        raise ValueError("suite has no delta_kl evidence to plot")

    values = list(probe.evidence["per_prompt"])

    # NOTE(review): the line is labelled "mean" and positioned at
    # ``probe.raw``, so ``raw`` is presumably the mean divergence — confirm.
    # When ``raw`` is absent, use the mean of the plotted values instead of
    # drawing a misleading "mean=0.000" line at the origin. ``values`` is
    # non-empty here thanks to the guard above.
    if probe.raw is not None:
        mean = float(probe.raw)
    else:
        mean = sum(float(v) for v in values) / len(values)

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    # Bin count scales with the sample size, clamped to the range [5, 20].
    ax.hist(values, bins=max(5, min(20, len(values) // 2)), color="#ff7f0e", edgecolor="white")
    ax.axvline(
        mean,
        color="black",
        linestyle="--",
        linewidth=1,
        label=f"mean={mean:.3f}",
    )
    ax.set_xlabel(probe.evidence.get("divergence_kind", "divergence"))
    ax.set_ylabel("count")
    ax.set_title("DeltaKL — per-prompt distribution")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
131
+
132
+
133
def _find_probe(suite: SuiteResult, kind: str) -> Any:
    """Return the first probe in ``suite.probes`` whose ``kind`` equals *kind*.

    Returns ``None`` when no probe matches.
    """
    matches = (candidate for candidate in suite.probes if candidate.kind == kind)
    return next(matches, None)
tests/unit/test_visualize.py (added)
@@ -0,0 +1,202 @@
1
+"""Tests for :mod:`dlm_sway.visualize`.
2
+
3
+Exercises the error path (matplotlib missing) and the happy path when
4
+the module is present by stubbing ``matplotlib.pyplot`` via sys.modules.
5
+"""
6
+
7
+from __future__ import annotations
8
+
9
+import sys
10
+import types
11
+from datetime import timedelta
12
+
13
+import pytest
14
+
15
+from dlm_sway.core.errors import BackendNotAvailableError
16
+from dlm_sway.core.result import ProbeResult, SuiteResult, Verdict, utcnow
17
+
18
+
19
def _suite_with(*probes: ProbeResult) -> SuiteResult:
    """Wrap *probes* in a ``SuiteResult`` with fixed placeholder metadata.

    The suite spans one second, starting at the current ``utcnow()``.
    """
    begin = utcnow()
    end = begin + timedelta(seconds=1)
    return SuiteResult(
        spec_path="sway.yaml",
        started_at=begin,
        finished_at=end,
        base_model_id="b",
        adapter_id="a",
        sway_version="0.1.0.dev0",
        probes=probes,
    )
30
+
31
+
32
class _FakeFig:
    """Minimal stand-in for a matplotlib ``Figure``; only ``tight_layout`` exists."""

    def tight_layout(self) -> None:  # pragma: no cover — trivial
        pass
35
+
36
+
37
class _FakeAx:
    """Stand-in for a matplotlib ``Axes`` that records which draw calls ran.

    ``bar``, ``plot`` and ``hist`` append their own name to ``calls``. Every
    other method accepts any arguments, does nothing, and returns ``None``.
    """

    def __init__(self) -> None:
        self.calls: list[str] = []

    def _ignore(self, *a, **k):  # type: ignore[no-untyped-def]
        """Accept and discard any call."""
        return None

    def bar(self, *a, **k):  # type: ignore[no-untyped-def]
        self.calls.append("bar")

    def plot(self, *a, **k):  # type: ignore[no-untyped-def]
        self.calls.append("plot")

    def hist(self, *a, **k):  # type: ignore[no-untyped-def]
        self.calls.append("hist")

    # Decoration and labelling calls are irrelevant to the tests, so they all
    # share the same no-op implementation.
    axhline = _ignore
    axvline = _ignore
    set_xticks = _ignore
    set_xticklabels = _ignore
    set_xlabel = _ignore
    set_ylabel = _ignore
    set_title = _ignore
    legend = _ignore
73
+
74
+
75
@pytest.fixture
def fake_mpl(monkeypatch: pytest.MonkeyPatch) -> _FakeAx:
    """Install stub ``matplotlib`` and ``matplotlib.pyplot`` modules.

    The stub ``pyplot.subplots`` returns a new ``_FakeFig`` paired with one
    shared ``_FakeAx``. That axes object is the fixture value, so tests can
    inspect which draw calls were recorded. Both ``sys.modules`` entries are
    set through ``monkeypatch``.
    """
    shared_ax = _FakeAx()

    stub_pyplot = types.ModuleType("matplotlib.pyplot")
    stub_pyplot.subplots = lambda *a, **k: (_FakeFig(), shared_ax)  # type: ignore[attr-defined]

    stub_modules = {
        "matplotlib": types.ModuleType("matplotlib"),
        "matplotlib.pyplot": stub_pyplot,
    }
    for module_name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)
    return shared_ax
88
+
89
+
90
def test_section_sis_plot_uses_per_section_evidence(fake_mpl: _FakeAx) -> None:
    """``plot_section_sis`` draws a bar chart from ``per_section`` evidence."""
    from dlm_sway.visualize import plot_section_sis

    # One passing row without a tag and one failing row with a tag, so both
    # label and color branches are exercised.
    untagged_pass = {
        "section_id": "a",
        "kind": "prose",
        "tag": None,
        "base_nll": 3.0,
        "ft_nll": 2.5,
        "own_lift": 0.17,
        "leak_lift": 0.02,
        "effective_sis": 0.15,
        "passed": True,
    }
    tagged_fail = {
        "section_id": "b",
        "kind": "instruction",
        "tag": "intro",
        "base_nll": 4.0,
        "ft_nll": 3.9,
        "own_lift": 0.025,
        "leak_lift": 0.03,
        "effective_sis": -0.005,
        "passed": False,
    }
    sis_probe = ProbeResult(
        name="sis",
        kind="section_internalization",
        verdict=Verdict.PASS,
        score=0.75,
        raw=0.1,
        evidence={
            "per_section": [untagged_pass, tagged_fail],
            "per_section_threshold": 0.05,
        },
    )

    plot_section_sis(_suite_with(sis_probe))

    assert "bar" in fake_mpl.calls
131
+
132
+
133
def test_adapter_ablation_plot(fake_mpl: _FakeAx) -> None:
    """``plot_adapter_ablation`` draws a line plot from ablation evidence."""
    from dlm_sway.visualize import plot_adapter_ablation

    ablation_evidence = {
        "lambdas": [0.0, 0.5, 1.0, 1.25],
        "mean_divergence_per_lambda": [0.0, 0.5, 1.0, 1.1],
        "linearity": 0.91,
        "saturation_lambda": 0.75,
        "overshoot": 1.1,
    }
    ablation_probe = ProbeResult(
        name="abl",
        kind="adapter_ablation",
        verdict=Verdict.PASS,
        score=0.8,
        raw=0.9,
        evidence=ablation_evidence,
    )

    plot_adapter_ablation(_suite_with(ablation_probe))

    assert "plot" in fake_mpl.calls
154
+
155
+
156
def test_kl_histogram_plot(fake_mpl: _FakeAx) -> None:
    """``plot_kl_histogram`` draws a histogram from ``per_prompt`` evidence."""
    from dlm_sway.visualize import plot_kl_histogram

    kl_probe = ProbeResult(
        name="dk",
        kind="delta_kl",
        verdict=Verdict.PASS,
        score=0.7,
        raw=0.1,
        evidence={
            "per_prompt": [0.05, 0.1, 0.12, 0.09, 0.15],
            "divergence_kind": "js",
        },
    )

    plot_kl_histogram(_suite_with(kl_probe))

    assert "hist" in fake_mpl.calls
171
+
172
+
173
def test_raises_when_matplotlib_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Plotting raises ``BackendNotAvailableError`` when matplotlib cannot be imported."""
    import builtins

    def _is_mpl(module_name: str) -> bool:
        return module_name == "matplotlib" or module_name.startswith("matplotlib.")

    # Purge matplotlib modules and block imports.
    for cached in [name for name in sys.modules if _is_mpl(name)]:
        monkeypatch.delitem(sys.modules, cached, raising=False)

    original_import = builtins.__import__

    def blocking_import(name: str, *a, **k):  # type: ignore[no-untyped-def]
        # Everything except matplotlib is delegated to the real importer.
        if _is_mpl(name):
            raise ImportError("matplotlib missing in this venv")
        return original_import(name, *a, **k)

    monkeypatch.setattr(builtins, "__import__", blocking_import)

    from dlm_sway.visualize import plot_section_sis

    empty_suite = _suite_with()
    with pytest.raises(BackendNotAvailableError):
        plot_section_sis(empty_suite)
195
+
196
+
197
def test_raises_when_no_matching_probe(fake_mpl: _FakeAx) -> None:
    """A suite without a ``section_internalization`` probe raises ``ValueError``."""
    from dlm_sway.visualize import plot_section_sis

    # No probes at all, so there is nothing for the SIS plot to find.
    probeless_suite = _suite_with()
    with pytest.raises(ValueError, match="section_internalization"):
        plot_section_sis(probeless_suite)