| 1 | """Tests for :mod:`dlm_sway.probes.paraphrase_invariance`.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 6 | from dlm_sway.core.result import Verdict |
| 7 | from dlm_sway.probes.base import RunContext, build_probe |
| 8 | |
| 9 | |
def _backend(*, par_lift_fraction: float, verb_lift: float = 10.0) -> DummyDifferentialBackend:
    """Build a dummy differential backend with controllable logprob lifts.

    Under the base view every ``(prompt, "A")`` pair scores ``-20.0``.
    Under the ft view the verbatim pair ``("Q", "A")`` is raised by
    ``verb_lift`` nats, and both paraphrase pairs are raised by
    ``par_lift_fraction * verb_lift``.
    """
    base_logprob = -20.0
    paraphrase_prompts = ("Q_par1", "Q_par2")

    # Verbatim entry goes first so key order is: Q, Q_par1, Q_par2.
    base_table = {("Q", "A"): base_logprob}
    ft_table = {("Q", "A"): base_logprob + verb_lift}
    for prompt in paraphrase_prompts:
        base_table[(prompt, "A")] = base_logprob
        ft_table[(prompt, "A")] = base_logprob + par_lift_fraction * verb_lift

    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_table),
        ft=DummyResponses(logprobs=ft_table),
    )
| 31 | |
| 32 | |
def test_pass_when_generalizing() -> None:
    """Expect PASS under ``generalize`` intent when paraphrases get 0.9 of the verbatim lift."""
    # Paraphrase lift close to the verbatim lift is the healthy-generalization case.
    dummy = _backend(par_lift_fraction=0.9)
    cases = [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1", "Q_par2"]}]
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "generalize",
            "min_verbatim_lift": 0.05,
            "min_generalization_ratio": 0.5,
            "cases": cases,
        }
    )

    outcome = probe.run(spec, RunContext(backend=dummy))

    assert outcome.verdict == Verdict.PASS
    # The raw score must be populated and clear the configured 0.5 ratio floor.
    assert outcome.raw is not None
    assert outcome.raw >= 0.5
| 51 | |
| 52 | |
def test_fails_when_only_memorized_but_intent_generalize() -> None:
    """Expect FAIL under ``generalize`` intent when only the verbatim prompt is lifted."""
    # A zero fraction leaves the paraphrase logprob at its base value.
    dummy = _backend(par_lift_fraction=0.0)
    cases = [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}]
    # No "min_generalization_ratio" is supplied here, so the probe uses
    # whatever value it falls back to for that setting.
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "generalize",
            "min_verbatim_lift": 0.05,
            "cases": cases,
        }
    )

    outcome = probe.run(spec, RunContext(backend=dummy))

    assert outcome.verdict == Verdict.FAIL
| 67 | |
| 68 | |
def test_passes_memorize_intent_when_only_memorized() -> None:
    """Expect PASS under ``memorize`` intent when only the verbatim prompt is lifted."""
    # Same backend as the generalize-intent failure case: zero paraphrase lift.
    dummy = _backend(par_lift_fraction=0.0)
    cases = [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}]
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "memorize",
            "min_verbatim_lift": 0.05,
            "max_generalization_ratio_if_memorize": 0.3,
            "cases": cases,
        }
    )

    outcome = probe.run(spec, RunContext(backend=dummy))

    assert outcome.verdict == Verdict.PASS
| 84 | |
| 85 | |
def test_error_on_empty_cases() -> None:
    """Expect an ERROR verdict when the spec has an empty ``cases`` list."""
    empty_config = {"name": "pi", "kind": "paraphrase_invariance", "cases": []}
    probe, spec = build_probe(empty_config)

    # The backend values are irrelevant here; any working backend will do.
    dummy = _backend(par_lift_fraction=0.9)
    outcome = probe.run(spec, RunContext(backend=dummy))

    assert outcome.verdict == Verdict.ERROR