| 1 | """Optional matplotlib-based visualizations. |
| 2 | |
| 3 | Behind the ``viz`` extra. Three functions cover the three plots that |
| 4 | make the sway report come alive in a notebook or saved PNG: |
| 5 | |
| 6 | - :func:`plot_section_sis`: per-section bar chart of effective SIS |
| 7 | (the flagship attribution view). |
| 8 | - :func:`plot_adapter_ablation`: the λ-scaled divergence curve — the |
| 9 | sway signature plot. |
| 10 | - :func:`plot_kl_histogram`: distribution of per-prompt KL divergences |
| 11 | (the raw data behind A1 DeltaKL). |
| 12 | |
| 13 | Each function raises :class:`~dlm_sway.core.errors.BackendNotAvailableError` |
| 14 | with a pip hint when matplotlib isn't installed. No function writes to |
| 15 | disk on your behalf — the caller decides (``fig.savefig(...)``). |
| 16 | """ |
| 17 | |
| 18 | from __future__ import annotations |
| 19 | |
| 20 | from typing import Any |
| 21 | |
| 22 | from dlm_sway.core.errors import BackendNotAvailableError |
| 23 | from dlm_sway.core.result import SuiteResult |
| 24 | |
| 25 | |
def _require_mpl() -> Any:
    """Import ``matplotlib.pyplot`` lazily and hand the module back.

    matplotlib is only needed by this optional module, so it is imported on
    first use rather than at module import time.

    Returns:
        The ``matplotlib.pyplot`` module.

    Raises:
        BackendNotAvailableError: the import failed; the error carries the
            ``viz`` extra name and a hint so the user knows what to install.
    """
    try:
        import matplotlib.pyplot as pyplot
    except ImportError as missing:
        raise BackendNotAvailableError(
            "visualize",
            extra="viz",
            hint="sway's visualization module needs matplotlib.",
        ) from missing
    return pyplot
| 37 | |
| 38 | |
def plot_section_sis(suite: SuiteResult) -> Any:
    """Draw one bar per section whose height is the row's ``effective_sis``.

    Data comes from the ``per_section`` evidence of the suite's
    ``section_internalization`` probe.

    - Bars are green when the row's ``passed`` flag is truthy and red otherwise.
    - A dashed gray line marks ``per_section_threshold``, or 0.0 when the
      evidence omits it.
    - Each tick shows the row's ``tag``, or the first 8 characters of its
      ``section_id`` when the tag is falsy, followed by ``(kind)`` on a
      second line.

    Returns the matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``section_internalization`` probe, or
            its ``per_section`` evidence is missing/empty.
    """
    pyplot = _require_mpl()

    sis_probe = _find_probe(suite, "section_internalization")
    have_rows = sis_probe is not None and bool(sis_probe.evidence.get("per_section"))
    if not have_rows:
        raise ValueError("suite has no section_internalization evidence to plot")

    sections: list[dict[str, Any]] = list(sis_probe.evidence["per_section"])
    count = len(sections)
    positions = range(count)

    def _tick(section: dict[str, Any]) -> str:
        # Prefer the human-readable tag; fall back to a short id prefix.
        name = section["tag"] or section["section_id"][:8]
        return f"{name}\n({section['kind']})"

    # Three separate passes (labels, heights, colours) so a malformed row
    # fails in the same order as a straightforward per-list build would.
    tick_labels = list(map(_tick, sections))
    heights = [float(section["effective_sis"]) for section in sections]
    bar_colors = [("#2ca02c" if section["passed"] else "#d62728") for section in sections]

    # Widen the canvas with the number of sections so tick labels stay legible.
    figure, axes = pyplot.subplots(figsize=(max(6.0, 0.7 * count), 4.0))
    axes.bar(positions, heights, color=bar_colors)
    threshold = float(sis_probe.evidence.get("per_section_threshold", 0.0))
    axes.axhline(threshold, color="gray", linestyle="--", linewidth=1, label="threshold")
    axes.set_xticks(positions)
    axes.set_xticklabels(tick_labels, rotation=30, ha="right")
    axes.set_ylabel("effective SIS")
    axes.set_title("Section Internalization Score")
    axes.legend(loc="best")
    figure.tight_layout()
    return figure
| 71 | |
| 72 | |
def plot_adapter_ablation(suite: SuiteResult) -> Any:
    """Render the signature λ-scaled divergence curve.

    Data comes from the suite's ``adapter_ablation`` probe evidence.

    - ``mean_divergence_per_lambda`` is plotted against ``lambdas``.
    - A dotted gray line marks λ=1, the trained adapter scale.
    - A dashed green line marks ``saturation_lambda`` when the evidence has one.
    - The title reports the evidence's ``linearity`` (R²) and ``overshoot``,
      each defaulting to 0.0 when absent.

    Returns the matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``adapter_ablation`` probe, its
            ``lambdas`` evidence is missing/empty, or the divergence series
            is missing or does not pair one-to-one with ``lambdas``.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "adapter_ablation")
    if probe is None or not probe.evidence.get("lambdas"):
        raise ValueError("suite has no adapter_ablation evidence to plot")

    lambdas = list(probe.evidence["lambdas"])
    divs = list(probe.evidence.get("mean_divergence_per_lambda") or [])
    # Validate before creating the figure. Otherwise a malformed probe would
    # surface as a KeyError or an opaque matplotlib shape error raised from
    # ax.plot, after a figure had already been registered with pyplot.
    if len(divs) != len(lambdas):
        raise ValueError(
            "adapter_ablation evidence has "
            f"{len(lambdas)} lambdas but {len(divs)} divergence values"
        )

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    ax.plot(lambdas, divs, marker="o", linewidth=2, color="#1f77b4")
    ax.axvline(1.0, color="gray", linestyle=":", linewidth=1, label="λ=1 (trained)")
    sat = probe.evidence.get("saturation_lambda")
    if sat is not None:
        sat_value = float(sat)
        ax.axvline(
            sat_value,
            color="#2ca02c",
            linestyle="--",
            linewidth=1,
            label=f"sat λ={sat_value:.2f}",
        )
    ax.set_xlabel("λ (adapter scale)")
    ax.set_ylabel("mean JS divergence vs λ=0")
    ax.set_title(
        f"Adapter Ablation (R²={float(probe.evidence.get('linearity', 0.0)):.2f}, "
        f"overshoot={float(probe.evidence.get('overshoot', 0.0)):.2f})"
    )
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
| 105 | |
| 106 | |
def plot_kl_histogram(suite: SuiteResult) -> Any:
    """Render the per-prompt KL distribution from a DeltaKL probe.

    The histogram is drawn from the ``per_prompt`` evidence of the suite's
    ``delta_kl`` probe.

    - Bin count is half the number of values, clamped to ``[5, 20]``.
    - A dashed black line marks the mean. The probe's ``raw`` value is used
      when it is set; otherwise the mean of the plotted values is used.
    - The x-axis is labelled with the evidence's ``divergence_kind``, or
      "divergence" when absent.

    Returns the matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``delta_kl`` probe or its
            ``per_prompt`` evidence is missing/empty.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "delta_kl")
    if probe is None or not probe.evidence.get("per_prompt"):
        raise ValueError("suite has no delta_kl evidence to plot")

    values = list(probe.evidence["per_prompt"])
    # The reference line is labelled "mean". When the probe has no headline
    # value, use the real mean of the plotted data rather than drawing a
    # fictitious line at 0.0. `values` is non-empty thanks to the guard above.
    if probe.raw is not None:
        mean = float(probe.raw)
    else:
        mean = sum(float(v) for v in values) / len(values)

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    ax.hist(values, bins=max(5, min(20, len(values) // 2)), color="#ff7f0e", edgecolor="white")
    ax.axvline(
        mean,
        color="black",
        linestyle="--",
        linewidth=1,
        label=f"mean={mean:.3f}",
    )
    ax.set_xlabel(probe.evidence.get("divergence_kind", "divergence"))
    ax.set_ylabel("count")
    ax.set_title("DeltaKL — per-prompt distribution")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
| 131 | |
| 132 | |
def _find_probe(suite: SuiteResult, kind: str) -> Any:
    """Return the first entry of ``suite.probes`` whose ``kind`` equals *kind*.

    Probes are scanned in order and the earliest match wins. ``None`` is
    returned when nothing matches.
    """
    return next((probe for probe in suite.probes if probe.kind == kind), None)