Python · 4703 bytes Raw Blame History
1 """Optional matplotlib-based visualizations.
2
3 Behind the ``viz`` extra. Three functions cover the three plots that
4 make the sway report come alive in a notebook or saved PNG:
5
6 - :func:`plot_section_sis`: per-section bar chart of effective SIS
7 (the flagship attribution view).
8 - :func:`plot_adapter_ablation`: the λ-scaled divergence curve — the
9 sway signature plot.
10 - :func:`plot_kl_histogram`: distribution of per-prompt KL divergences
11 (the raw data behind A1 DeltaKL).
12
13 Each function raises :class:`~dlm_sway.core.errors.BackendNotAvailableError`
14 with a pip hint when matplotlib isn't installed. No function writes to
15 disk on your behalf — the caller decides (``fig.savefig(...)``).
16 """
17
18 from __future__ import annotations
19
20 from typing import Any
21
22 from dlm_sway.core.errors import BackendNotAvailableError
23 from dlm_sway.core.result import SuiteResult
24
25
def _require_mpl() -> Any:
    """Import and hand back ``matplotlib.pyplot``, or fail with an install hint.

    The import happens here rather than at module level, so importing this
    module does not itself require matplotlib.

    Returns:
        The ``matplotlib.pyplot`` module.

    Raises:
        BackendNotAvailableError: matplotlib cannot be imported; the original
            ``ImportError`` is chained as the cause.
    """
    try:
        import matplotlib.pyplot as pyplot_module
    except ImportError as missing:
        raise BackendNotAvailableError(
            "visualize",
            extra="viz",
            hint="sway's visualization module needs matplotlib.",
        ) from missing
    return pyplot_module
37
38
def plot_section_sis(suite: SuiteResult) -> Any:
    """Render a per-section ``effective_sis`` bar chart.

    Draws one bar per entry of the ``section_internalization`` probe's
    ``per_section`` evidence. A bar is green when the row's ``passed`` flag
    is truthy and red otherwise. When the evidence records a
    ``per_section_threshold``, it is drawn as a dashed horizontal
    reference line.

    Args:
        suite: Suite result expected to contain a ``section_internalization``
            probe with non-empty ``per_section`` evidence.

    Returns:
        The matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``section_internalization`` evidence.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "section_internalization")
    if probe is None or not probe.evidence.get("per_section"):
        raise ValueError("suite has no section_internalization evidence to plot")

    rows: list[dict[str, Any]] = list(probe.evidence["per_section"])
    # Label with the section's tag when set, else a short section-id prefix.
    labels = [
        f"{row.get('tag') or row['section_id'][:8]}\n({row['kind']})" for row in rows
    ]
    values = [float(row["effective_sis"]) for row in rows]
    colors = ["#2ca02c" if row["passed"] else "#d62728" for row in rows]

    # Widen the figure with the bar count so rotated tick labels stay legible.
    fig, ax = plt.subplots(figsize=(max(6.0, 0.7 * len(rows)), 4.0))
    ax.bar(range(len(rows)), values, color=colors)

    # Draw the threshold only when the probe recorded one. A 0.0 default
    # would draw a line the probe never reported, and an explicit None
    # cannot be passed to float().
    threshold = probe.evidence.get("per_section_threshold")
    if threshold is not None:
        ax.axhline(
            float(threshold),
            color="gray",
            linestyle="--",
            linewidth=1,
            label="threshold",
        )

    ax.set_xticks(range(len(rows)))
    ax.set_xticklabels(labels, rotation=30, ha="right")
    ax.set_ylabel("effective SIS")
    ax.set_title("Section Internalization Score")
    # The threshold line is the only labelled artist; without it a legend
    # would be empty and matplotlib warns about that.
    if threshold is not None:
        ax.legend(loc="best")
    fig.tight_layout()
    return fig
71
72
def plot_adapter_ablation(suite: SuiteResult) -> Any:
    """Render the signature λ-scaled divergence curve.

    Plots the ``adapter_ablation`` probe's ``mean_divergence_per_lambda``
    against its ``lambdas`` and marks λ=1 with a dotted line. If
    ``saturation_lambda`` is recorded, it is marked with a dashed line. The
    title reports the ``linearity`` (shown as R²) and ``overshoot`` values,
    or ``n/a`` for whichever is not recorded.

    Args:
        suite: Suite result expected to contain an ``adapter_ablation``
            probe with non-empty ``lambdas`` evidence.

    Returns:
        The matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``adapter_ablation`` evidence, or its
            divergence series is missing or does not line up with
            ``lambdas``.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "adapter_ablation")
    if probe is None or not probe.evidence.get("lambdas"):
        raise ValueError("suite has no adapter_ablation evidence to plot")
    evidence = probe.evidence

    lambdas = list(evidence["lambdas"])
    divs = list(evidence.get("mean_divergence_per_lambda") or [])
    # Otherwise a missing or mismatched series would surface as a bare
    # KeyError or matplotlib's x/y shape error. Neither names the probe
    # at fault.
    if len(divs) != len(lambdas):
        raise ValueError(
            "adapter_ablation evidence is inconsistent: "
            f"{len(lambdas)} lambdas but {len(divs)} divergence values"
        )

    def _stat(key: str) -> str:
        # Evidence values may be None, as saturation_lambda can be. Render
        # missing or None statistics as "n/a" instead of crashing in
        # float() or printing a made-up 0.00.
        value = evidence.get(key)
        return "n/a" if value is None else f"{float(value):.2f}"

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    ax.plot(lambdas, divs, marker="o", linewidth=2, color="#1f77b4")
    ax.axvline(1.0, color="gray", linestyle=":", linewidth=1, label="λ=1 (trained)")
    sat = evidence.get("saturation_lambda")
    if sat is not None:
        ax.axvline(
            float(sat),
            color="#2ca02c",
            linestyle="--",
            linewidth=1,
            label=f"sat λ={float(sat):.2f}",
        )
    ax.set_xlabel("λ (adapter scale)")
    ax.set_ylabel("mean JS divergence vs λ=0")
    ax.set_title(
        f"Adapter Ablation (R²={_stat('linearity')}, overshoot={_stat('overshoot')})"
    )
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
105
106
def plot_kl_histogram(suite: SuiteResult) -> Any:
    """Render the per-prompt KL distribution from a DeltaKL probe.

    Draws a histogram of the ``delta_kl`` probe's ``per_prompt`` evidence
    and marks the mean with a dashed vertical line.

    Args:
        suite: Suite result expected to contain a ``delta_kl`` probe with
            non-empty ``per_prompt`` evidence.

    Returns:
        The matplotlib ``Figure``; the caller handles display / save.

    Raises:
        BackendNotAvailableError: matplotlib is not installed.
        ValueError: the suite has no ``delta_kl`` evidence.
    """
    plt = _require_mpl()

    probe = _find_probe(suite, "delta_kl")
    if probe is None or not probe.evidence.get("per_prompt"):
        raise ValueError("suite has no delta_kl evidence to plot")

    values = list(probe.evidence["per_prompt"])
    # The marker is labelled "mean" and uses probe.raw as that value. When
    # raw is missing, fall back to the mean of the plotted values rather
    # than drawing the line at an arbitrary 0. `values` is non-empty here,
    # so the division is safe.
    if probe.raw is not None:
        mean = float(probe.raw)
    else:
        mean = sum(float(v) for v in values) / len(values)

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    # About two samples per bin, clamped to the range 5..20 bins.
    ax.hist(
        values,
        bins=max(5, min(20, len(values) // 2)),
        color="#ff7f0e",
        edgecolor="white",
    )
    ax.axvline(
        mean,
        color="black",
        linestyle="--",
        linewidth=1,
        label=f"mean={mean:.3f}",
    )
    ax.set_xlabel(probe.evidence.get("divergence_kind") or "divergence")
    ax.set_ylabel("count")
    ax.set_title("DeltaKL — per-prompt distribution")
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
131
132
133 def _find_probe(suite: SuiteResult, kind: str) -> Any:
134 for p in suite.probes:
135 if p.kind == kind:
136 return p
137 return None