tenseleyflow/sway / f0b5d50

Browse files

sway(probes): C1 style_fingerprint (6-dim, numpy-only)

Authored by espadonne
SHA
f0b5d50d49353608a8d9360870092ced21daa52f
Parents
1b49ff0
Tree
dadfd7a

2 changed files

Status | File | Lines added | Lines removed
A src/dlm_sway/probes/style_fingerprint.py 179 0
A tests/unit/test_probe_style_fingerprint.py 115 0
src/dlm_sway/probes/style_fingerprint.py (added)
@@ -0,0 +1,179 @@
1
+"""C1 StyleFingerprint — does ft prose *read* like the doc?
2
+
3
+Generates base and ft completions from a set of stylistic prompts,
4
+extracts a 6-dimensional fingerprint from each, and measures how the ft
5
+fingerprint has shifted **toward** the training document's own
6
+fingerprint vs the base.
7
+
8
+We compute the fingerprint with numpy-only features so the probe works
9
+out of the box without spaCy/textstat. The optional ``style`` extra
10
+upgrades the fingerprint with passive-voice rate and POS-entropy in a
11
+later milestone; the numeric contract — a non-negative vector per text
12
+— is stable across that upgrade.
13
+
14
+Signal: ``style_shift = cos(ft_fp - base_fp, doc_fp - base_fp)`` in
15
+fingerprint space. Positive values mean ft has moved *toward* the
16
+doc's style; negative values mean it moved *away* (a bad sign);
17
+near-zero means no stylistic shift detectable.
18
+"""
19
+
20
+from __future__ import annotations
21
+
22
+import re
23
+import statistics
24
+from typing import Literal
25
+
26
+import numpy as np
27
+from numpy.typing import NDArray
28
+from pydantic import Field
29
+
30
+from dlm_sway.core.result import ProbeResult, Verdict
31
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
32
+
33
# Sentence boundary: split after terminal punctuation followed by whitespace.
_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
# Paragraph boundary: a blank (possibly whitespace-only) line.
_PARAGRAPH_SPLIT = re.compile(r"\n\s*\n")
# A "word" is alphabetic, optionally continuing with letters, apostrophes, hyphens.
_WORD_RE = re.compile(r"\b[A-Za-z][A-Za-z'-]*\b")
# Characters counted toward punctuation density.
_PUNCTS = set(".,:;!?-—()[]\"'/")


def fingerprint(text: str) -> NDArray[np.float64]:
    """Compute the 6-dimensional stylistic fingerprint of ``text``.

    Components, each scaled to be roughly order-1:
      0. mean sentence length in words, divided by 30
      1. population std of sentence length in words, divided by 30
      2. type-token ratio (unique lowercased words / total words)
      3. average word length in characters, divided by 10
      4. punctuation characters per character of text, times 10
      5. inverse average paragraph length in words, times 30

    Blank text, or text containing no alphabetic words, maps to the
    zero vector.
    """
    if not text.strip():
        return np.zeros(6, dtype=np.float64)

    all_words = _WORD_RE.findall(text)
    if not all_words:
        return np.zeros(6, dtype=np.float64)

    # Per-sentence word counts; sentences with no words contribute nothing.
    per_sentence = [
        n
        for n in (
            len(_WORD_RE.findall(s)) for s in _SENTENCE_SPLIT.split(text) if s.strip()
        )
        if n > 0
    ]
    if not per_sentence:
        # Degenerate case: treat the whole text as one sentence.
        per_sentence = [len(all_words)]

    sent_mean = statistics.fmean(per_sentence)
    sent_std = statistics.pstdev(per_sentence) if len(per_sentence) > 1 else 0.0

    type_token = len({w.lower() for w in all_words}) / len(all_words)
    word_len = statistics.fmean(len(w) for w in all_words)

    punct_per_char = sum(ch in _PUNCTS for ch in text) / max(len(text), 1)

    paras = [p for p in _PARAGRAPH_SPLIT.split(text) if p.strip()]
    para_len = (
        statistics.fmean(len(_WORD_RE.findall(p)) for p in paras) if paras else len(all_words)
    )
    inv_para = 1.0 / max(para_len, 1.0)

    return np.asarray(
        [
            sent_mean / 30.0,
            sent_std / 30.0,
            type_token,
            word_len / 10.0,
            punct_per_char * 10.0,
            inv_para * 30.0,
        ],
        dtype=np.float64,
    )
86
+
87
+
88
class StyleFingerprintSpec(ProbeSpec):
    # Spec for the C1 style-fingerprint probe.
    # NOTE(review): deliberately no class docstring added — pydantic may
    # copy it into the model's JSON-schema description, changing output.
    # Discriminator value that routes a raw config dict to this spec type.
    kind: Literal["style_fingerprint"] = "style_fingerprint"
    prompts: list[str] = Field(default_factory=list)
    """Prompts used to elicit a stylistic sample from each model."""
    doc_reference: str = ""
    """Concatenated reference text representing the adapter's intended
    style. Typically the document itself; the .dlm bridge supplies this
    from ``ctx.doc_text`` when left empty."""
    # Per-prompt generation cap, applied to both base and ft completions.
    max_new_tokens: int = 128
    assert_shift_gte: float = 0.25
    """Minimum cosine shift for PASS. ``0.25`` is a deliberately
    permissive default — stylistic shift is a weaker signal than
    perplexity lift."""
101
+
102
+
103
class StyleFingerprintProbe(Probe):
    """C1 calibration probe: does ft prose read like the doc?

    Generates completions for every prompt from both models, fingerprints
    the pooled base/ft texts and the reference doc, and scores the cosine
    between the ft-vs-base shift and the doc-vs-base shift in
    fingerprint space.
    """

    kind = "style_fingerprint"
    spec_cls = StyleFingerprintSpec
    category = "calibration"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Execute the probe.

        Returns ERROR when no prompts are configured, SKIP when no
        reference text is available (neither ``spec.doc_reference`` nor
        ``ctx.doc_text``), otherwise PASS iff the cosine shift meets
        ``spec.assert_shift_gte``.
        """
        assert isinstance(spec, StyleFingerprintSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        doc_text = spec.doc_reference or (ctx.doc_text or "")
        if not doc_text.strip():
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no doc_reference (inline or from ctx.doc_text)",
            )

        # Generate all base samples under one context, then all ft samples,
        # instead of toggling per prompt: for real backends each
        # as_base()/as_finetuned() switch can mean an adapter swap, so this
        # cuts switches from 2*len(prompts) to 2. Each generation is seeded
        # individually, so per-prompt outputs are unaffected by the order.
        with ctx.backend.as_base() as b:
            base_samples = [
                b.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                for prompt in spec.prompts
            ]
        with ctx.backend.as_finetuned() as f:
            ft_samples = [
                f.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                for prompt in spec.prompts
            ]

        base_fp = fingerprint("\n".join(base_samples))
        ft_fp = fingerprint("\n".join(ft_samples))
        doc_fp = fingerprint(doc_text)

        shift = _cosine_shift(base_fp, ft_fp, doc_fp)
        verdict = Verdict.PASS if shift >= spec.assert_shift_gte else Verdict.FAIL
        # Map cosine in [-1, 1] onto a [0, 1] score.
        score = float(np.clip((shift + 1.0) / 2.0, 0.0, 1.0))

        # Label the direction honestly: exactly 0.0 means "no shift", not
        # "away from" (the previous two-way ternary mislabelled that case).
        if shift > 0:
            direction = "toward doc"
        elif shift < 0:
            direction = "away from doc"
        else:
            direction = "no shift vs doc"

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=shift,
            evidence={
                "base_fp": base_fp.tolist(),
                "ft_fp": ft_fp.tolist(),
                "doc_fp": doc_fp.tolist(),
                "style_shift": shift,
                "weight": spec.weight,
            },
            message=(
                f"style_shift={shift:+.2f} "
                f"({direction}, threshold={spec.assert_shift_gte})"
            ),
        )
167
+
168
+
169
+def _cosine_shift(
170
+    base: NDArray[np.float64], ft: NDArray[np.float64], doc: NDArray[np.float64]
171
+) -> float:
172
+    """Cosine between (ft - base) and (doc - base) in fingerprint space."""
173
+    a = ft - base
174
+    b = doc - base
175
+    na = float(np.linalg.norm(a))
176
+    nb = float(np.linalg.norm(b))
177
+    if na == 0.0 or nb == 0.0:
178
+        return 0.0
179
+    return float(np.dot(a, b) / (na * nb))
tests/unit/test_probe_style_fingerprint.pyadded
@@ -0,0 +1,115 @@
1
+"""Tests for :mod:`dlm_sway.probes.style_fingerprint`."""
2
+
3
+from __future__ import annotations
4
+
5
+import numpy as np
6
+
7
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
8
+from dlm_sway.core.result import Verdict
9
+from dlm_sway.probes.base import RunContext, build_probe
10
+from dlm_sway.probes.style_fingerprint import fingerprint
11
+
12
+
13
class TestFingerprint:
    """Unit tests for the numpy-only ``fingerprint`` feature extractor."""

    def test_zero_vector_for_empty(self) -> None:
        vec = fingerprint("")
        assert vec.shape == (6,)
        assert np.allclose(vec, 0.0)

    def test_non_zero_for_normal_text(self) -> None:
        vec = fingerprint("This is a sentence. This is another one. A third.")
        assert vec.shape == (6,)
        # Ordinary prose must yield positive mean sentence length,
        # type-token ratio, and average word length.
        assert vec[0] > 0
        assert vec[2] > 0
        assert vec[3] > 0

    def test_distinct_styles_distinct_fingerprints(self) -> None:
        clipped = "Go. Now. Quick."
        florid = (
            "We must, with all deliberate speed and measured consideration, "
            "proceed expeditiously towards the elaborated and carefully "
            "constructed resolution of the foregoing matter."
        )
        assert not np.allclose(fingerprint(clipped), fingerprint(florid))
34
+
35
+
36
def _backend_with_samples(base: list[str], ft: list[str]) -> DummyDifferentialBackend:
    """Build a dummy backend mapping prompts ``p0, p1, ...`` to canned samples."""
    base_map = {f"p{i}": text for i, text in enumerate(base)}
    ft_map = {f"p{i}": text for i, text in enumerate(ft)}
    return DummyDifferentialBackend(
        base=DummyResponses(generations=base_map),
        ft=DummyResponses(generations=ft_map),
    )
41
+
42
+
43
class TestProbe:
    """End-to-end tests for the style_fingerprint probe on the dummy backend."""

    @staticmethod
    def _run(backend: DummyDifferentialBackend, config: dict):
        # Build the probe from a raw config dict and execute it once.
        probe, spec = build_probe(config)
        return probe.run(spec, RunContext(backend=backend))

    def test_pass_when_ft_drifts_toward_doc(self) -> None:
        terse = ["Short. Plain. Words."] * 2
        ornate = [
            "Wherein many clauses conjoin themselves, through extended "
            "ruminations, unto a meandering whole of considerable length."
        ] * 2
        doc = (
            "Wherein many clauses conjoin themselves, through extended "
            "ruminations, unto a meandering whole of considerable length. "
            "Further elaboration, no less copious, follows apace."
        )
        result = self._run(
            _backend_with_samples(terse, ornate),
            {
                "name": "c1",
                "kind": "style_fingerprint",
                "prompts": ["p0", "p1"],
                "doc_reference": doc,
                "max_new_tokens": 32,
                "assert_shift_gte": 0.2,
            },
        )
        assert result.verdict == Verdict.PASS
        assert result.raw is not None
        assert result.raw > 0.2

    def test_fail_when_no_stylistic_shift(self) -> None:
        # Identical base and ft samples: no shift, so the probe must FAIL.
        same = ["Short. Plain. Words."] * 2
        result = self._run(
            _backend_with_samples(same, list(same)),
            {
                "name": "c1",
                "kind": "style_fingerprint",
                "prompts": ["p0", "p1"],
                "doc_reference": "Wherein clauses conjoin into meandering wholes of length.",
                "assert_shift_gte": 0.25,
            },
        )
        assert result.verdict == Verdict.FAIL

    def test_skip_without_doc_reference(self) -> None:
        # No inline doc_reference and no ctx.doc_text: SKIP, not FAIL.
        result = self._run(
            _backend_with_samples(["x"], ["y"]),
            {
                "name": "c1",
                "kind": "style_fingerprint",
                "prompts": ["p0"],
            },
        )
        assert result.verdict == Verdict.SKIP

    def test_error_on_empty_prompts(self) -> None:
        result = self._run(
            _backend_with_samples([], []),
            {
                "name": "c1",
                "kind": "style_fingerprint",
                "prompts": [],
                "doc_reference": "doc",
            },
        )
        assert result.verdict == Verdict.ERROR