tenseleyflow/sway / 5cf3d2e

sway(bridge): autogen full suite from .dlm sections

Authored by espadonne
SHA: 5cf3d2e17bfcc4dd48052a501482c08122326545
Parents: c5ec164
Tree: c513e09

2 changed files

Status  File                                        +    -
M       pyproject.toml                              2    0
M       src/dlm_sway/integrations/dlm/autogen.py    177  13
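For orientation, a minimal sketch of how the new entry point is meant to be driven (the removed stub's docstring notes the module is loaded lazily by ``dlm-sway autogen``). The file names below are illustrative only, not part of this commit.

    # Hypothetical usage of the bridge this commit fills in; paths are placeholders.
    from pathlib import Path

    from dlm_sway.integrations.dlm.autogen import write_sway_yaml

    # Requires a trained adapter under ~/.dlm/store/<dlm_id>/adapter; otherwise
    # write_sway_yaml raises SwayError and points the user at `dlm train`.
    write_sway_yaml(Path("handbook.dlm"), Path("sway.yaml"))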
pyproject.toml (modified)
@@ -142,6 +142,8 @@ ignore = [
 # PyTorch's canonical `import torch.nn.functional as F` is universally
 # read, so we allow the naming exception in the HF backend only.
 "src/dlm_sway/backends/hf.py" = ["N812"]
+# The .dlm bridge is the one place allowed to import the ``dlm`` package.
+"src/dlm_sway/integrations/dlm/*.py" = ["TID251"]
 
 [tool.ruff.lint.flake8-tidy-imports.banned-api]
 # Hard architectural boundary: the `dlm` package is only importable
src/dlm_sway/integrations/dlm/autogen.py (modified)
@@ -1,27 +1,191 @@
 """Auto-generate a ``sway.yaml`` from a ``.dlm`` document.
 
-Populated by P8 (the .dlm bridge). This module is imported lazily by
-``dlm-sway autogen`` so its presence doesn't fail the HF-only path. The
-real implementation maps :mod:`dlm.doc.sections` to sway's
-:class:`~dlm_sway.core.sections.Section` and emits a spec with every
-shipped primitive wired up.
+Walks the parsed sections and emits one entry per primitive sway ships:
+the full 11-primitive battery wired up against the document's own
+content. The result is a YAML artifact the user commits alongside their
+``.dlm`` and diffs in PRs.
+
+The generated spec includes a ``dlm_source`` field that the suite loader
+uses to pick up :class:`~dlm_sway.core.sections.Section` data at run
+time — probes that need sections (B1, B3, C3) then work against the
+typed structure instead of re-parsing text.
 """
 
 from __future__ import annotations
 
 from pathlib import Path
+from typing import Any
+
+import yaml
 
 from dlm_sway.core.errors import SwayError
+from dlm_sway.core.sections import Section
+from dlm_sway.integrations.dlm.resolver import DlmHandle, resolve_dlm
 
 
 def write_sway_yaml(dlm_path: Path, out: Path) -> None:
-    """Write a generated sway.yaml to ``out`` based on the .dlm at ``dlm_path``.
+    """Resolve the .dlm, build a spec dict, write it as YAML to ``out``."""
+    handle = resolve_dlm(dlm_path)
+    if handle.adapter_path is None:
+        raise SwayError(
+            f"{dlm_path}: no trained adapter found at ~/.dlm/store/{handle.dlm_id}/adapter; "
+            "train the document with `dlm train` before generating a sway suite."
+        )
+    spec = build_spec_dict(handle, dlm_source=str(dlm_path.resolve()))
+    out.write_text(yaml.safe_dump(spec, sort_keys=False), encoding="utf-8")
+
+
+def build_spec_dict(handle: DlmHandle, *, dlm_source: str | None = None) -> dict[str, Any]:
+    """Build a sway.yaml-shaped dict from a :class:`DlmHandle`."""
+    base_spec = {"kind": "hf", "base": handle.base_model}
+    ft_spec = {
+        "kind": "hf",
+        "base": handle.base_model,
+        "adapter": str(handle.adapter_path) if handle.adapter_path else None,
+    }
+    spec: dict[str, Any] = {
+        "version": 1,
+        "models": {"base": base_spec, "ft": ft_spec},
+        "defaults": {"seed": 0, "differential": True},
+        "suite": _build_suite(handle.sections),
+    }
+    if dlm_source is not None:
+        spec["dlm_source"] = dlm_source
+    return spec
+
+
+def _build_suite(sections: tuple[Section, ...]) -> list[dict[str, Any]]:
+    """Assemble the full probe battery for the given sections.
 
-    Not yet implemented — the .dlm bridge lands in a later milestone.
+    The ordering matters: ``null_adapter`` first so every downstream
+    probe's z-score threshold has stats to consult.
     """
-    del dlm_path, out
-    raise SwayError(
-        "dlm-sway autogen is not yet implemented — the .dlm bridge is "
-        "scheduled for the next milestone. Track progress at "
-        "https://github.com/tenseleyFlow/DocumentLanguageModel"
+    instruction_probes: list[tuple[str, str]] = [
+        (p.prompt, p.gold) for s in sections if s.kind == "instruction" for p in s.probes
+    ]
+    prose_prompts: list[str] = []
+    for s in sections:
+        if s.kind == "prose" and s.content.strip():
+            # Use the section's leading sentence as a natural completion prompt.
+            first_sentence = s.content.split(".")[0].strip()
+            if first_sentence:
+                prose_prompts.append(first_sentence + ".")
+
+    kl_prompts = [q for q, _ in instruction_probes][:16] or prose_prompts[:16]
+    style_prompts = prose_prompts[:8] or [q for q, _ in instruction_probes][:8]
+
+    suite: list[dict[str, Any]] = []
+
+    # Baseline calibration — always first.
+    suite.append({"name": "null_baseline", "kind": "null_adapter", "runs": 3})
+
+    # Adherence.
+    if kl_prompts:
+        suite.append(
+            {
+                "name": "delta_kl_doc",
+                "kind": "delta_kl",
+                "prompts": kl_prompts,
+                "assert_mean_gte": 0.02,
+            }
+        )
+    if instruction_probes:
+        suite.append(
+            {
+                "name": "revert_check",
+                "kind": "adapter_revert",
+                "cases": [
+                    {"prompt": q, "gold": a, "paraphrases": _auto_paraphrases(q)}
+                    for q, a in instruction_probes[:8]
+                ],
+                "assert_revert_rate_lt": 0.3,
+            }
+        )
+    if kl_prompts:
+        suite.append(
+            {
+                "name": "prompt_collapse",
+                "kind": "prompt_collapse",
+                "prompts": kl_prompts[:4],
+                "context_lengths": [0, 256, 512, 1024],
+                "assert_half_life_tokens": 300,
+            }
         )
+
+    # Attribution.
+    if len(sections) >= 2:
+        suite.append(
+            {
+                "name": "section_attribution",
+                "kind": "section_internalization",
+                "per_section_threshold": 0.05,
+            }
+        )
+    if instruction_probes:
+        suite.append(
+            {
+                "name": "paraphrase_invariance",
+                "kind": "paraphrase_invariance",
+                "cases": [
+                    {"prompt": q, "gold": a, "paraphrases": _auto_paraphrases(q)}
+                    for q, a in instruction_probes[:6]
+                ],
+            }
+        )
+    has_preferences = any(s.kind == "preference" and s.preferences for s in sections)
+    if has_preferences:
+        suite.append(
+            {
+                "name": "preference_flip",
+                "kind": "preference_flip",
+                "assert_flip_rate_gte": 0.7,
+            }
+        )
+
+    # Calibration.
+    if style_prompts:
+        suite.append(
+            {
+                "name": "style_shift",
+                "kind": "style_fingerprint",
+                "prompts": style_prompts,
+            }
+        )
+    suite.append({"name": "general_knowledge", "kind": "calibration_drift"})
+    if any(s.kind == "prose" for s in sections):
+        suite.append(
+            {
+                "name": "verbatim_leak",
+                "kind": "leakage",
+                "prefix_chars": 128,
+                "continuation_chars": 256,
+            }
+        )
+
+    # Signature ablation — goes last because it's the most expensive.
+    if kl_prompts:
+        suite.append(
+            {
+                "name": "adapter_ablation",
+                "kind": "adapter_ablation",
+                "prompts": kl_prompts[:6],
+                "lambdas": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
+            }
+        )
+
+    return suite
+
+
+def _auto_paraphrases(prompt: str) -> list[str]:
+    """Small, deterministic paraphrase set used when authors don't supply one.
+
+    Purely heuristic — good enough to detect "did the model memorize the
+    exact wording". Real paraphrase generation lives behind the
+    ``semsim`` extra.
+    """
+    variants: list[str] = []
+    stripped = prompt.rstrip("?. ")
+    variants.append(f"Could you explain: {stripped}?")
+    variants.append(f"I'd like to know — {stripped}.")
+    variants.append(f"Please describe: {stripped}.")
+    return variants[:3]
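For reference, the dict that build_spec_dict assembles (and that yaml.safe_dump then writes to ``out``) has roughly the shape sketched below. Values are placeholders and the suite is abridged; which probe entries actually appear depends on the guards in _build_suite.

    # Sketch of the generated spec (placeholder values, abridged suite).
    spec = {
        "version": 1,
        "models": {
            "base": {"kind": "hf", "base": "<base-model-id>"},
            "ft": {
                "kind": "hf",
                "base": "<base-model-id>",
                "adapter": "~/.dlm/store/<dlm_id>/adapter",
            },
        },
        "defaults": {"seed": 0, "differential": True},
        "suite": [
            {"name": "null_baseline", "kind": "null_adapter", "runs": 3},
            {"name": "delta_kl_doc", "kind": "delta_kl", "prompts": ["..."], "assert_mean_gte": 0.02},
            # ... up to 11 entries, ending with adapter_ablation ...
        ],
        "dlm_source": "/abs/path/to/document.dlm",
    }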