@@ -0,0 +1,179 @@
+"""C1 StyleFingerprint: does fine-tuned (ft) prose *read* like the doc?
+
+Generates base and ft completions from a set of stylistic prompts,
+extracts a 6-dimensional fingerprint from each, and measures whether the
+ft fingerprint has shifted **toward** the training document's own
+fingerprint relative to the base.
+
+The fingerprint uses numpy-only features so the probe works out of the
+box without spaCy/textstat. The optional ``style`` extra upgrades the
+fingerprint with passive-voice rate and POS entropy in a later
+milestone; the numeric contract (a non-negative vector per text) is
+stable across that upgrade.
+
+Signal: ``style_shift = cos(ft_fp - base_fp, doc_fp - base_fp)`` in
+fingerprint space. Positive values mean the ft model has moved *toward*
+the doc's style; negative values mean it has moved *away* (a bad sign);
+near-zero means the shift is negligible or orthogonal to the doc
+direction, i.e. no detectable pull toward the doc's style.
+"""
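+
+# A minimal sketch of the signal with toy vectors (illustrative numbers, not
+# real fingerprints); ``_cosine_shift`` is defined at the bottom of this file:
+#
+#     >>> import numpy as np
+#     >>> base = np.array([0.5, 0.2, 0.6, 0.4, 0.3, 0.1])
+#     >>> doc = np.array([0.9, 0.4, 0.5, 0.5, 0.6, 0.2])
+#     >>> ft = base + 0.5 * (doc - base)  # ft drifted halfway toward the doc
+#     >>> round(_cosine_shift(base, ft, doc), 2)
+#     1.0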
+
+from __future__ import annotations
+
+import re
+import statistics
+from typing import Literal
+
+import numpy as np
+from numpy.typing import NDArray
+from pydantic import Field
+
+from dlm_sway.core.result import ProbeResult, Verdict
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
+
+_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
+_PARAGRAPH_SPLIT = re.compile(r"\n\s*\n")
+_WORD_RE = re.compile(r"\b[A-Za-z][A-Za-z'-]*\b")
+_PUNCTS = set(".,:;!?-—()[]\"'/")
+
+
+def fingerprint(text: str) -> NDArray[np.float64]:
+    """Return a 6-dim stylistic fingerprint for ``text``.
+
+    Dimensions (all numeric, scaled to order-1):
+    0. mean sentence length (words) / 30.0
+    1. std sentence length (words) / 30.0
+    2. type-token ratio (already in [0,1])
+    3. avg word length (chars) / 10.0
+    4. punctuation density per char * 10.0
+    5. paragraph density (1 / avg paragraph length in words) * 30.0
+    """
+    if not text.strip():
+        return np.zeros(6, dtype=np.float64)
+
+    sentences = [s for s in _SENTENCE_SPLIT.split(text) if s.strip()]
+    paragraphs = [p for p in _PARAGRAPH_SPLIT.split(text) if p.strip()]
+    words = _WORD_RE.findall(text)
+    if not words:
+        return np.zeros(6, dtype=np.float64)
+
+    sentence_word_counts = [len(_WORD_RE.findall(s)) for s in sentences]
+    sentence_word_counts = [c for c in sentence_word_counts if c > 0]
+    if not sentence_word_counts:
+        sentence_word_counts = [len(words)]
+
+    mean_sent = statistics.fmean(sentence_word_counts)
+    std_sent = statistics.pstdev(sentence_word_counts) if len(sentence_word_counts) > 1 else 0.0
+    ttr = len({w.lower() for w in words}) / len(words)
+    avg_word_len = statistics.fmean(len(w) for w in words)
+    punct_count = sum(ch in _PUNCTS for ch in text)
+    punct_density = punct_count / max(len(text), 1)
+    avg_paragraph_len = (
+        statistics.fmean(len(_WORD_RE.findall(p)) for p in paragraphs) if paragraphs else len(words)
+    )
+    paragraph_density = 1.0 / max(avg_paragraph_len, 1.0)
+
+    return np.asarray(
+        [
+            mean_sent / 30.0,
+            std_sent / 30.0,
+            ttr,
+            avg_word_len / 10.0,
+            punct_density * 10.0,
+            paragraph_density * 30.0,
+        ],
+        dtype=np.float64,
+    )
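+
+# Example usage (hypothetical snippet of prose; the exact feature values
+# depend on the text, so none are asserted here):
+#
+#     fp = fingerprint("Short sentences. Then a much longer, meandering sentence follows.")
+#     fp.shape  # -> (6,)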
+
+
+class StyleFingerprintSpec(ProbeSpec):
+    kind: Literal["style_fingerprint"] = "style_fingerprint"
+    prompts: list[str] = Field(default_factory=list)
+    """Prompts used to elicit a stylistic sample from each model."""
+    doc_reference: str = ""
+    """Concatenated reference text representing the adapter's intended
+    style. Typically the document itself; the .dlm bridge supplies this
+    from ``ctx.doc_text`` when left empty."""
+    max_new_tokens: int = 128
+    assert_shift_gte: float = 0.25
+    """Minimum cosine shift for PASS. ``0.25`` is a deliberately
+    permissive default — stylistic shift is a weaker signal than
+    perplexity lift."""
+
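+# An illustrative spec (values are made up; ``name`` comes from the shared
+# ``ProbeSpec`` base class):
+#
+#     StyleFingerprintSpec(
+#         name="style-drift",
+#         prompts=["Draft a short status update in the document's usual voice."],
+#         assert_shift_gte=0.2,
+#     )
+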
+
+class StyleFingerprintProbe(Probe):
+    kind = "style_fingerprint"
+    spec_cls = StyleFingerprintSpec
+    category = "calibration"
+
+    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
+        assert isinstance(spec, StyleFingerprintSpec)
+        if not spec.prompts:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.ERROR,
+                score=None,
+                message="no prompts provided",
+            )
+        doc_text = spec.doc_reference or (ctx.doc_text or "")
+        if not doc_text.strip():
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.SKIP,
+                score=None,
+                message="no doc_reference (inline or from ctx.doc_text)",
+            )
+
+        base_samples: list[str] = []
+        ft_samples: list[str] = []
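+        # Both models see the same prompt and seed, so any fingerprint
+        # difference reflects the adapter rather than decoding randomness.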
+        for prompt in spec.prompts:
+            with ctx.backend.as_base() as b:
+                base_samples.append(
+                    b.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
+                )
+            with ctx.backend.as_finetuned() as f:
+                ft_samples.append(
+                    f.generate(prompt, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
+                )
+
+        base_fp = fingerprint("\n".join(base_samples))
+        ft_fp = fingerprint("\n".join(ft_samples))
+        doc_fp = fingerprint(doc_text)
+
+        shift = _cosine_shift(base_fp, ft_fp, doc_fp)
+        verdict = Verdict.PASS if shift >= spec.assert_shift_gte else Verdict.FAIL
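+        # Map the cosine shift from [-1, 1] onto a [0, 1] score (clipped for safety).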
+        score = float(np.clip((shift + 1.0) / 2.0, 0.0, 1.0))
+
+        return ProbeResult(
+            name=spec.name,
+            kind=spec.kind,
+            verdict=verdict,
+            score=score,
+            raw=shift,
+            evidence={
+                "base_fp": base_fp.tolist(),
+                "ft_fp": ft_fp.tolist(),
+                "doc_fp": doc_fp.tolist(),
+                "style_shift": shift,
+                "weight": spec.weight,
+            },
+            message=(
+                f"style_shift={shift:+.2f} "
+                f"({'toward' if shift > 0 else 'away from'} doc, "
+                f"threshold={spec.assert_shift_gte})"
+            ),
+        )
+
+
+def _cosine_shift(
+    base: NDArray[np.float64], ft: NDArray[np.float64], doc: NDArray[np.float64]
+) -> float:
+    """Cosine between (ft - base) and (doc - base) in fingerprint space."""
+    a = ft - base
+    b = doc - base
+    na = float(np.linalg.norm(a))
+    nb = float(np.linalg.norm(b))
+    if na == 0.0 or nb == 0.0:
+        return 0.0
+    return float(np.dot(a, b) / (na * nb))