@@ -0,0 +1,179 @@ |
| 1 | +"""C1 StyleFingerprint — does ft prose *read* like the doc? |
| 2 | + |
| 3 | +Generates base and ft completions from a set of stylistic prompts, |
| 4 | +extracts a 6-dimensional fingerprint from each, and measures how the ft |
| 5 | +fingerprint has shifted **toward** the training document's own |
| 6 | +fingerprint vs the base. |
| 7 | + |
| 8 | +We compute the fingerprint with numpy-only features so the probe works |
| 9 | +out of the box without spaCy/textstat. The optional ``style`` extra |
| 10 | +upgrades the fingerprint with passive-voice rate and POS-entropy in a |
| 11 | +later milestone; the numeric contract — a non-negative vector per text |
| 12 | +— is stable across that upgrade. |
| 13 | + |
| 14 | +Signal: ``style_shift = cos(ft_fp - base_fp, doc_fp - base_fp)`` in |
| 15 | +fingerprint space. Positive values mean ft has moved *toward* the |
| 16 | +doc's style; negative values mean it moved *away* (a bad sign); |
| 17 | +near-zero means no stylistic shift detectable. |
| 18 | +""" |
| 19 | + |
| 20 | +from __future__ import annotations |
| 21 | + |
| 22 | +import re |
| 23 | +import statistics |
| 24 | +from typing import Literal |
| 25 | + |
| 26 | +import numpy as np |
| 27 | +from numpy.typing import NDArray |
| 28 | +from pydantic import Field |
| 29 | + |
| 30 | +from dlm_sway.core.result import ProbeResult, Verdict |
| 31 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 32 | + |
# Sentence boundary: whitespace that follows terminal punctuation (. ! ?).
_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
# Paragraph boundary: a blank (possibly whitespace-only) line.
_PARAGRAPH_SPLIT = re.compile(r"\n\s*\n")
# A "word": starts with an ASCII letter, may continue with letters,
# apostrophes, or hyphens (so contractions and hyphenated words count once).
_WORD_RE = re.compile(r"\b[A-Za-z][A-Za-z'-]*\b")
# Characters counted toward punctuation density in ``fingerprint``.
_PUNCTS = set(".,:;!?-—()[]\"'/")
| 37 | + |
| 38 | + |
def fingerprint(text: str) -> NDArray[np.float64]:
    """Compute the 6-dimensional stylistic fingerprint of ``text``.

    Every component is scaled to roughly order-1 so no single dimension
    dominates the cosine comparison:

    0. mean sentence length in words, divided by 30
    1. population std of sentence length in words, divided by 30
    2. type-token ratio (naturally in [0, 1])
    3. average word length in characters, divided by 10
    4. punctuation marks per character, times 10
    5. inverse of the average paragraph length in words, times 30

    Blank input, or input containing no recognizable words, maps to the
    zero vector.
    """
    zero = np.zeros(6, dtype=np.float64)
    if not text.strip():
        return zero

    words = _WORD_RE.findall(text)
    if not words:
        return zero

    # Words per sentence, dropping sentences that contain no words
    # (e.g. bare punctuation fragments produced by the split).
    per_sentence = [
        n
        for n in (
            len(_WORD_RE.findall(s))
            for s in _SENTENCE_SPLIT.split(text)
            if s.strip()
        )
        if n > 0
    ]
    if not per_sentence:
        # Degenerate split: treat the whole text as one sentence.
        per_sentence = [len(words)]

    mean_len = statistics.fmean(per_sentence)
    spread = statistics.pstdev(per_sentence) if len(per_sentence) > 1 else 0.0
    distinct = {w.lower() for w in words}
    avg_chars = statistics.fmean(len(w) for w in words)
    marks = sum(1 for ch in text if ch in _PUNCTS)

    paragraphs = [p for p in _PARAGRAPH_SPLIT.split(text) if p.strip()]
    avg_para_words = (
        statistics.fmean(len(_WORD_RE.findall(p)) for p in paragraphs)
        if paragraphs
        else len(words)
    )

    return np.asarray(
        [
            mean_len / 30.0,
            spread / 30.0,
            len(distinct) / len(words),
            avg_chars / 10.0,
            (marks / max(len(text), 1)) * 10.0,
            (1.0 / max(avg_para_words, 1.0)) * 30.0,
        ],
        dtype=np.float64,
    )
| 86 | + |
| 87 | + |
class StyleFingerprintSpec(ProbeSpec):
    """Spec for the C1 style-fingerprint probe.

    Elicits completions from both models via ``prompts``, fingerprints
    the pooled text, and compares the ft shift against ``doc_reference``
    in fingerprint space; PASS requires a cosine shift of at least
    ``assert_shift_gte`` toward the doc.
    """

    kind: Literal["style_fingerprint"] = "style_fingerprint"
    prompts: list[str] = Field(default_factory=list)
    """Prompts used to elicit a stylistic sample from each model."""
    doc_reference: str = ""
    """Concatenated reference text representing the adapter's intended
    style. Typically the document itself; the .dlm bridge supplies this
    from ``ctx.doc_text`` when left empty."""
    max_new_tokens: int = 128  # generation budget per prompt, per model
    assert_shift_gte: float = 0.25
    """Minimum cosine shift for PASS. ``0.25`` is a deliberately
    permissive default — stylistic shift is a weaker signal than
    perplexity lift."""
| 101 | + |
| 102 | + |
class StyleFingerprintProbe(Probe):
    """C1 — does ft prose *read* like the training document?

    Generates completions from the base and finetuned models for each
    configured prompt, fingerprints the pooled text of each model plus
    the reference document, and scores the cosine between the ft
    displacement (ft - base) and the doc displacement (doc - base) in
    fingerprint space.
    """

    kind = "style_fingerprint"
    spec_cls = StyleFingerprintSpec
    category = "calibration"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Execute the probe.

        Returns ERROR when no prompts are configured, SKIP when no
        reference text is available (neither inline in the spec nor via
        ``ctx.doc_text``), otherwise PASS/FAIL against
        ``spec.assert_shift_gte``.
        """
        assert isinstance(spec, StyleFingerprintSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        doc_text = spec.doc_reference or (ctx.doc_text or "")
        if not doc_text.strip():
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no doc_reference (inline or from ctx.doc_text)",
            )

        # Enter each adapter context once and run every prompt inside it,
        # instead of toggling base/ft per prompt (2 context switches
        # rather than 2*N; adapter swaps are typically the expensive
        # part). Each generate() call is seeded explicitly, so batching
        # by model does not change the sampled text.
        # NOTE(review): assumes generate() carries no hidden state across
        # calls within one context — confirm against the backend.
        with ctx.backend.as_base() as b:
            base_samples = [
                b.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                for prompt in spec.prompts
            ]
        with ctx.backend.as_finetuned() as f:
            ft_samples = [
                f.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                for prompt in spec.prompts
            ]

        # Pool each model's samples into one text so short completions
        # don't yield degenerate per-sample fingerprints.
        base_fp = fingerprint("\n".join(base_samples))
        ft_fp = fingerprint("\n".join(ft_samples))
        doc_fp = fingerprint(doc_text)

        shift = _cosine_shift(base_fp, ft_fp, doc_fp)
        verdict = Verdict.PASS if shift >= spec.assert_shift_gte else Verdict.FAIL
        # Map cosine in [-1, 1] onto the [0, 1] score scale.
        score = float(np.clip((shift + 1.0) / 2.0, 0.0, 1.0))

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=shift,
            evidence={
                "base_fp": base_fp.tolist(),
                "ft_fp": ft_fp.tolist(),
                "doc_fp": doc_fp.tolist(),
                "style_shift": shift,
                "weight": spec.weight,
            },
            message=(
                f"style_shift={shift:+.2f} "
                f"({'toward' if shift > 0 else 'away from'} doc, "
                f"threshold={spec.assert_shift_gte})"
            ),
        )
| 167 | + |
| 168 | + |
| 169 | +def _cosine_shift( |
| 170 | + base: NDArray[np.float64], ft: NDArray[np.float64], doc: NDArray[np.float64] |
| 171 | +) -> float: |
| 172 | + """Cosine between (ft - base) and (doc - base) in fingerprint space.""" |
| 173 | + a = ft - base |
| 174 | + b = doc - base |
| 175 | + na = float(np.linalg.norm(a)) |
| 176 | + nb = float(np.linalg.norm(b)) |
| 177 | + if na == 0.0 or nb == 0.0: |
| 178 | + return 0.0 |
| 179 | + return float(np.dot(a, b) / (na * nb)) |