Python · 3088 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.probes.paraphrase_invariance`."""
2
3 from __future__ import annotations
4
5 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6 from dlm_sway.core.result import Verdict
7 from dlm_sway.probes.base import RunContext, build_probe
8
9
def _backend(*, par_lift_fraction: float, verb_lift: float = 10.0) -> DummyDifferentialBackend:
    """Build a dummy differential backend with configurable lifts.

    Under the base view every (prompt, gold) pair has a logprob of ``-20.0``.
    Under the ft view the verbatim pair ``("Q", "A")`` is raised by
    ``verb_lift`` nats, and each paraphrase pair is raised by
    ``par_lift_fraction * verb_lift``.
    """
    floor = -20.0
    verbatim_pair = ("Q", "A")
    paraphrase_pairs = (("Q_par1", "A"), ("Q_par2", "A"))
    paraphrase_score = floor + par_lift_fraction * verb_lift

    # Verbatim pair first, paraphrases after, in both views.
    base_logprobs = {
        verbatim_pair: floor,
        **{pair: floor for pair in paraphrase_pairs},
    }
    ft_logprobs = {
        verbatim_pair: floor + verb_lift,
        **{pair: paraphrase_score for pair in paraphrase_pairs},
    }

    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_logprobs),
        ft=DummyResponses(logprobs=ft_logprobs),
    )
31
32
def test_pass_when_generalizing() -> None:
    """Expect PASS under ``generalize`` intent when paraphrases get 90% of the verbatim lift."""
    config = {
        "name": "pi",
        "kind": "paraphrase_invariance",
        "intent": "generalize",
        "min_verbatim_lift": 0.05,
        "min_generalization_ratio": 0.5,
        "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1", "Q_par2"]}],
    }
    # Strong verbatim lift together with a strong paraphrase lift.
    dummy = _backend(par_lift_fraction=0.9)
    probe, spec = build_probe(config)

    result = probe.run(spec, RunContext(backend=dummy))

    assert result.verdict == Verdict.PASS
    # The raw score must be present and at least 0.5, the same value
    # configured above as ``min_generalization_ratio``.
    assert result.raw is not None
    assert result.raw >= 0.5
51
52
def test_fails_when_only_memorized_but_intent_generalize() -> None:
    """Expect FAIL under ``generalize`` intent when paraphrases receive zero lift."""
    config = {
        "name": "pi",
        "kind": "paraphrase_invariance",
        "intent": "generalize",
        "min_verbatim_lift": 0.05,
        "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}],
    }
    # Only the verbatim pair is lifted; paraphrase logprobs stay at the base value.
    dummy = _backend(par_lift_fraction=0.0)
    probe, spec = build_probe(config)

    result = probe.run(spec, RunContext(backend=dummy))

    assert result.verdict == Verdict.FAIL
67
68
def test_passes_memorize_intent_when_only_memorized() -> None:
    """Expect PASS under ``memorize`` intent when paraphrases receive zero lift."""
    config = {
        "name": "pi",
        "kind": "paraphrase_invariance",
        "intent": "memorize",
        "min_verbatim_lift": 0.05,
        "max_generalization_ratio_if_memorize": 0.3,
        "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}],
    }
    # Only the verbatim pair is lifted; paraphrase logprobs stay at the base value.
    dummy = _backend(par_lift_fraction=0.0)
    probe, spec = build_probe(config)

    result = probe.run(spec, RunContext(backend=dummy))

    assert result.verdict == Verdict.PASS
84
85
def test_error_on_empty_cases() -> None:
    """Expect an ERROR verdict when the probe is run with an empty ``cases`` list."""
    config = {"name": "pi", "kind": "paraphrase_invariance", "cases": []}
    probe, spec = build_probe(config)

    dummy = _backend(par_lift_fraction=0.9)
    result = probe.run(spec, RunContext(backend=dummy))

    assert result.verdict == Verdict.ERROR