tenseleyflow/documentlanguagemodel / c1a4b05

Browse files

test(inference,cli): cover format_vl_prompt, load_images, + prompt --image refusals

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
c1a4b058855bd850a97c6f6f0cc8b36ca958cafd
Parents
a29e7e9
Tree
65385a2

2 changed files

StatusFile+-
A tests/unit/cli/test_prompt_image_flag.py 108 0
A tests/unit/inference/test_vl_generate.py 88 0
tests/unit/cli/test_prompt_image_flag.pyadded
@@ -0,0 +1,108 @@
1
+"""`dlm prompt --image` flag validation (Sprint 35 v1).
2
+
3
+- Passing --image to a text-base doc exits 2 with an informative message.
4
+- Omitting --image on a VL-base doc exits 2 with an actionable hint.
5
+- Both exits happen before any HF-model load, so CLI-level tests cover
6
+  them without touching torch / transformers weights.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
+from pathlib import Path
12
+
13
+from typer.testing import CliRunner
14
+
15
+from dlm.cli.app import app
16
+
17
+
18
+def _joined_output(result: object) -> str:
19
+    text = getattr(result, "output", "") + getattr(result, "stderr", "")
20
+    return " ".join(text.split())
21
+
22
+
23
def _scaffold_text_doc(tmp_path: Path) -> Path:
    """Create a flat text-base doc via `dlm init` and return its path."""
    doc_path = tmp_path / "text.dlm"
    init_args = [
        "--home",
        str(tmp_path / "home"),
        "init",
        str(doc_path),
        "--base",
        "smollm2-135m",
    ]
    outcome = CliRunner().invoke(app, init_args)
    assert outcome.exit_code == 0, outcome.output
    return doc_path
40
+
41
+
42
def _scaffold_vl_doc(tmp_path: Path) -> Path:
    """Create a doc pinned to the PaliGemma base, license pre-accepted.

    Gemma acceptance matters for `dlm train` / `dlm export`, not for
    `dlm prompt` — the resolver enforces acceptance only when the store
    has none on record. Pinning base_model here sidesteps that without
    a real HF download (the CLI never performs one at init time).
    """
    doc_path = tmp_path / "vl.dlm"
    init_args = [
        "--home",
        str(tmp_path / "home"),
        "init",
        str(doc_path),
        "--base",
        "paligemma-3b-mix-224",
        "--i-accept-license",
    ]
    outcome = CliRunner().invoke(app, init_args)
    assert outcome.exit_code == 0, outcome.output
    return doc_path
66
+
67
+
68
class TestTextBaseRefusesImage:
    """`dlm prompt --image` on a text-only base must refuse before model load."""

    def test_text_base_with_image_exits_2(self, tmp_path: Path) -> None:
        doc = _scaffold_text_doc(tmp_path)
        image_path = tmp_path / "x.png"
        image_path.write_bytes(b"\x89PNG fake")
        prompt_args = [
            "--home",
            str(tmp_path / "home"),
            "prompt",
            str(doc),
            "hello",
            "--image",
            str(image_path),
        ]
        result = CliRunner().invoke(app, prompt_args)
        assert result.exit_code == 2, result.output
        combined = _joined_output(result)
        assert "--image is only valid with vision-language bases" in combined
89
+
90
+
91
class TestVlBaseRequiresImage:
    """`dlm prompt` without --image on a VL base must exit with a hint."""

    def test_vl_base_without_image_exits_2(self, tmp_path: Path) -> None:
        doc = _scaffold_vl_doc(tmp_path)
        prompt_args = [
            "--home",
            str(tmp_path / "home"),
            "prompt",
            str(doc),
            "hello",
        ]
        result = CliRunner().invoke(app, prompt_args)
        assert result.exit_code == 2, result.output
        combined = _joined_output(result)
        assert "vision-language" in combined
        assert "--image" in combined
tests/unit/inference/test_vl_generate.pyadded
@@ -0,0 +1,88 @@
1
+"""VL generate — prompt formatting + image loading (Sprint 35 v1).
2
+
3
+The real `generate_vl` + `load_for_vl_inference` paths are pragma'd
4
+(they need a real VL HF model); these tests cover the pure helpers.
5
+"""
6
+
7
+from __future__ import annotations
8
+
9
+from pathlib import Path
10
+
11
+import pytest
12
+from PIL import Image
13
+
14
+from dlm.inference.vl_generate import format_vl_prompt, load_images
15
+
16
+
17
class TestFormatVlPrompt:
    """format_vl_prompt — placeholder-token placement rules."""

    def test_prepends_token_before_text(self) -> None:
        result = format_vl_prompt("describe this", image_token="<image>", num_images=1)
        assert result == "<image>\ndescribe this"

    def test_multiple_images_repeat_token(self) -> None:
        result = format_vl_prompt("compare", image_token="<image>", num_images=3)
        assert result == "<image><image><image>\ncompare"

    def test_empty_prompt_emits_tokens_only(self) -> None:
        # An empty prompt is valid — caption-only usage. The output is the
        # placeholders alone, with no trailing separator.
        result = format_vl_prompt("", image_token="<image>", num_images=2)
        assert result == "<image><image>"

    def test_user_placed_token_respected(self) -> None:
        # A prompt that already contains the image token comes back
        # untouched — the user positioned it deliberately.
        prompt = "Compare the before <image> and after <image> shots."
        result = format_vl_prompt(prompt, image_token="<image>", num_images=2)
        assert result == prompt

    def test_custom_image_token(self) -> None:
        result = format_vl_prompt("describe", image_token="<|vision|>", num_images=1)
        assert result == "<|vision|>\ndescribe"
50
+
51
+
52
class TestLoadImages:
    """load_images — ordering, RGB conversion, and error reporting."""

    def _write_png(self, path: Path, color: tuple[int, int, int]) -> None:
        # 2x2 solid-color PNG fixture.
        fixture = Image.new("RGB", (2, 2), color=color)
        fixture.save(path, format="PNG")

    def test_loads_single_image(self, tmp_path: Path) -> None:
        png = tmp_path / "a.png"
        self._write_png(png, (255, 0, 0))
        loaded = load_images([png])
        assert len(loaded) == 1
        assert isinstance(loaded[0], Image.Image)
        assert loaded[0].mode == "RGB"

    def test_loads_multiple_images_preserves_order(self, tmp_path: Path) -> None:
        red_path = tmp_path / "a.png"
        green_path = tmp_path / "b.png"
        self._write_png(red_path, (255, 0, 0))
        self._write_png(green_path, (0, 255, 0))
        first, second = load_images([red_path, green_path])
        assert first.getpixel((0, 0)) == (255, 0, 0)
        assert second.getpixel((0, 0)) == (0, 255, 0)

    def test_missing_file_raises_clearly(self, tmp_path: Path) -> None:
        missing = tmp_path / "nope.png"
        with pytest.raises(FileNotFoundError, match="image not found"):
            load_images([missing])

    def test_non_image_raises_pil_error(self, tmp_path: Path) -> None:
        bogus = tmp_path / "x.png"
        bogus.write_text("not an image", encoding="utf-8")
        with pytest.raises(Exception):  # noqa: B017 — PIL's UnidentifiedImageError
            load_images([bogus])

    def test_converts_to_rgb_from_other_modes(self, tmp_path: Path) -> None:
        # A 4-channel RGBA source must come back as plain RGB.
        rgba_path = tmp_path / "rgba.png"
        Image.new("RGBA", (2, 2), color=(255, 0, 0, 128)).save(rgba_path, format="PNG")
        (image,) = load_images([rgba_path])
        assert image.mode == "RGB"