tenseleyflow/sway / 332e32f

tests/integration: end-to-end multi-turn smoke on a tiny LoRA

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 332e32f298f2d628e36d10db406e968777dd9ba8
Parents: d955dfd
Tree: 15180a1

1 changed file

Status  File                                                   +    -
A       tests/integration/test_probe_multi_turn_coherence.py   109  0

tests/integration/test_probe_multi_turn_coherence.py (added)
@@ -0,0 +1,109 @@
+"""Integration test: multi_turn_coherence_decay end-to-end on a tiny LoRA.
+
+Builds a tiny random LoRA on SmolLM2-135M-Instruct (which has a real
+chat_template) and runs the probe through 3 turns of synthetic
+dialogue. The intent isn't to assert specific KL values — they
+depend on the random adapter — but to exercise the full code path
+on a real backend so a regression in the chat-template wiring,
+turn-loop, or curve-fit plumbing surfaces in slow CI.
+
+Marked ``slow + online``.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
+    """Same shape as the other slow-lane backend tests."""
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                # Tiny perturbation — base != ft, but generations stay sane
+                # enough to thread through a few dialogue turns without
+                # collapsing to junk.
+                param.copy_(torch.randn_like(param) * 0.02)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    out = tmp_path_factory.mktemp("multi-turn-coherence-adapter")
+    _build_random_lora_adapter(tiny_model_dir, out)
+    return out
+
+
+def test_probe_runs_end_to_end_on_real_adapter(tiny_model_dir: Path, random_adapter: Path) -> None:
+    """Smoke: HF backend + chat_template-equipped tokenizer + real
+    multi-turn dialogue produces a finalized result with the
+    documented evidence keys + finite per-turn KLs."""
+    from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
+    from dlm_sway.core.model import ModelSpec
+    from dlm_sway.core.result import Verdict
+    from dlm_sway.probes.base import RunContext, build_probe
+
+    backend = HuggingFaceDifferentialBackend(
+        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
+        adapter_path=random_adapter,
+    )
+    try:
+        probe, spec = build_probe(
+            {
+                "name": "mtc_smoke",
+                "kind": "multi_turn_coherence_decay",
+                "prompts": [
+                    "What's the difference between TCP and UDP?",
+                    "Explain how a neural network learns.",
+                ],
+                "max_turns": 3,
+                "max_new_tokens": 32,  # keep CPU runtime under control
+            }
+        )
+        ctx = RunContext(backend=backend, seed=0, top_k=64)
+        result = probe.run(spec, ctx)
+    finally:
+        backend.close()
+
+    # Shape: any verdict that isn't ERROR is fine. We don't pin
+    # PASS/FAIL because the random adapter's actual decay shape isn't
+    # under our control.
+    assert result.verdict in {
+        Verdict.PASS,
+        Verdict.FAIL,
+        Verdict.WARN,
+    }, result.message
+    assert result.evidence["max_turns"] == 3
+    assert result.evidence["num_prompts"] == 2
+    per_turn = result.evidence["per_turn_kls"]
+    assert len(per_turn) == 2  # max_turns - 1
+    for kl in per_turn:
+        assert isinstance(kl, float)
+        assert 0.0 <= kl < float("inf")  # finite and non-negative (NaN fails the comparison too)
+    assert result.evidence["fit_status"] in {"ok", "stable", "non_monotonic", "degenerate"}
+    sparkline = result.evidence["sparkline"]
+    assert isinstance(sparkline, str)
+    assert len(sparkline) == 2
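
A note on the shape being asserted: the test pins down the evidence dict's keys, lengths, and value ranges, not the values themselves. A minimal sketch of what a run with this config might produce (the numbers below are invented for illustration; only the structure is guaranteed by the assertions):

    # Hypothetical evidence payload for max_turns=3 and two prompts.
    evidence = {
        "max_turns": 3,
        "num_prompts": 2,
        "per_turn_kls": [0.41, 0.57],  # max_turns - 1 entries, each finite and >= 0
        "fit_status": "ok",            # or "stable" / "non_monotonic" / "degenerate"
        "sparkline": "▂▅",             # same length as per_turn_kls
    }

Given the slow/online markers, and assuming they are registered in the project's pytest configuration, the test can be selected locally with pytest -m "slow and online".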