| 1 | """Tests for :mod:`dlm_sway.probes.section_internalization` (the flagship B1).""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import numpy as np |
| 6 | |
| 7 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 8 | from dlm_sway.core.result import Verdict |
| 9 | from dlm_sway.core.scoring import RollingLogprob |
| 10 | from dlm_sway.core.sections import Section, SectionProbe |
| 11 | from dlm_sway.probes.base import RunContext, build_probe |
| 12 | |
| 13 | |
def _rolling(mean_lp: float, n: int = 10) -> RollingLogprob:
    """Return a synthetic ``RollingLogprob`` with a constant per-token logprob.

    The result carries ``n`` token ids (``0 .. n-1``) and ``n - 1`` logprob
    entries, every one equal to ``mean_lp``; ``total_logprob`` is the sum of
    those entries.
    """
    token_ids = np.arange(n, dtype=np.int64)
    # One fewer logprob than tokens, each filled with the requested value.
    per_token = np.full(n - 1, mean_lp, dtype=np.float32)
    total = float(per_token.sum())
    return RollingLogprob(
        token_ids=token_ids,
        logprobs=per_token,
        num_tokens=n,
        total_logprob=total,
    )
| 22 | |
| 23 | |
def _section(sid: str, kind: str = "prose", content: str = "content", probes=()) -> Section:
    """Build a ``Section`` fixture; *probes* may be any iterable and is stored as a tuple."""
    frozen_probes = tuple(probes)
    return Section(  # type: ignore[arg-type]
        id=sid,
        kind=kind,
        content=content,
        probes=frozen_probes,
    )
| 26 | |
| 27 | |
def test_skip_without_sections() -> None:
    """A run context built without any sections must yield a SKIP verdict."""
    config = {"name": "sis", "kind": "section_internalization"}
    probe, spec = build_probe(config)

    # Both views use empty dummy responses; the context is given no sections.
    ctx = RunContext(
        backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses()),
    )

    outcome = probe.run(spec, ctx)
    assert outcome.verdict == Verdict.SKIP
| 34 | |
| 35 | |
def test_skip_with_single_section() -> None:
    """A run context holding exactly one section must yield a SKIP verdict."""
    config = {"name": "sis", "kind": "section_internalization"}
    probe, spec = build_probe(config)

    only_section = _section("a")
    ctx = RunContext(
        backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses()),
        sections=(only_section,),
    )

    outcome = probe.run(spec, ctx)
    assert outcome.verdict == Verdict.SKIP
| 42 | |
| 43 | |
def test_pass_when_each_section_gets_distinct_lift() -> None:
    """Run the probe on two prose sections whose ft logprobs beat base.

    The base view scores both sections at -3.0 per token; the ft view scores
    section ``a`` at -1.0 and section ``b`` at -2.5, so both sections improve
    but by different amounts.

    NOTE(review): the verdict assertion accepts either PASS or FAIL, so this
    test only pins that the probe reaches a decision (does not SKIP/ERROR) and
    reports evidence for both sections — confirm whether PASS alone was
    intended.
    """
    text_a = "aaa " * 10
    text_b = "bbb " * 10

    base_view = DummyResponses(
        rolling={text_a: _rolling(-3.0), text_b: _rolling(-3.0)},
    )
    ft_view = DummyResponses(
        rolling={text_a: _rolling(-1.0), text_b: _rolling(-2.5)},
    )

    config = {
        "name": "sis",
        "kind": "section_internalization",
        "per_section_threshold": 0.05,
    }
    probe, spec = build_probe(config)

    ctx = RunContext(
        backend=DummyDifferentialBackend(base=base_view, ft=ft_view),
        sections=(
            _section("a", content=text_a),
            _section("b", content=text_b),
        ),
    )
    outcome = probe.run(spec, ctx)

    decided = (Verdict.PASS, Verdict.FAIL)
    assert outcome.verdict in decided
    assert "per_section" in outcome.evidence
    assert len(outcome.evidence["per_section"]) == 2
| 71 | |
| 72 | |
def test_instruction_uses_logprob_of() -> None:
    """Instruction sections with more ft lift must report a larger effective_sis.

    Each instruction section carries one prompt/gold probe pair, and the dummy
    views are populated only with logprobs keyed by those pairs. Base scores
    both pairs at -10.0; ft scores pair ``a`` at -3.0 and pair ``b`` at -8.0.
    """
    base_view = DummyResponses(
        logprobs={("Qa", "Aa"): -10.0, ("Qb", "Ab"): -10.0},
    )
    ft_view = DummyResponses(
        logprobs={("Qa", "Aa"): -3.0, ("Qb", "Ab"): -8.0},
    )

    section_a = _section(
        "a",
        kind="instruction",
        content="...",
        probes=(SectionProbe(prompt="Qa", gold="Aa"),),
    )
    section_b = _section(
        "b",
        kind="instruction",
        content="...",
        probes=(SectionProbe(prompt="Qb", gold="Ab"),),
    )

    config = {
        "name": "sis",
        "kind": "section_internalization",
        "per_section_threshold": 0.05,
    }
    probe, spec = build_probe(config)

    ctx = RunContext(
        backend=DummyDifferentialBackend(base=base_view, ft=ft_view),
        sections=(section_a, section_b),
    )
    outcome = probe.run(spec, ctx)

    # Index the per-section rows by id so the two scores can be compared.
    effective = {}
    for row in outcome.evidence["per_section"]:
        effective[row["section_id"]] = row["effective_sis"]

    # Section a improved by 7.0 versus 2.0 for b, so it should score higher.
    assert effective["a"] > effective["b"]