@@ -0,0 +1,88 @@ |
| 1 | +"""VL generate — prompt formatting + image loading (Sprint 35 v1). |
| 2 | + |
| 3 | +The real `generate_vl` + `load_for_vl_inference` paths are pragma'd |
| 4 | +(they need a real VL HF model); these tests cover the pure helpers. |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +from pathlib import Path |
| 10 | + |
| 11 | +import pytest |
| 12 | +from PIL import Image |
| 13 | + |
| 14 | +from dlm.inference.vl_generate import format_vl_prompt, load_images |
| 15 | + |
| 16 | + |
class TestFormatVlPrompt:
    """Pin where `format_vl_prompt` puts image placeholders around the text."""

    def test_prepends_token_before_text(self) -> None:
        # One image: a single placeholder, a newline, then the user's text.
        rendered = format_vl_prompt(
            "describe this",
            image_token="<image>",
            num_images=1,
        )
        assert rendered == "<image>\ndescribe this"

    def test_multiple_images_repeat_token(self) -> None:
        # The placeholder appears once per image, back to back.
        rendered = format_vl_prompt("compare", image_token="<image>", num_images=3)
        expected = "<image><image><image>\ncompare"
        assert rendered == expected

    def test_empty_prompt_emits_tokens_only(self) -> None:
        # An empty prompt is legitimate: the caller just wants the images
        # captioned. Only the placeholders come back, with no separator
        # dangling after them.
        rendered = format_vl_prompt("", image_token="<image>", num_images=2)
        assert rendered == "<image><image>"

    def test_user_placed_token_respected(self) -> None:
        # A prompt that already contains the image token is passed through
        # untouched, since the caller positioned the placeholders on purpose.
        text = "Compare the before <image> and after <image> shots."
        rendered = format_vl_prompt(text, image_token="<image>", num_images=2)
        assert rendered == text

    def test_custom_image_token(self) -> None:
        # The placeholder string is whatever `image_token` says it is.
        token = "<|vision|>"
        rendered = format_vl_prompt("describe", image_token=token, num_images=1)
        assert rendered == "<|vision|>\ndescribe"
| 50 | + |
| 51 | + |
class TestLoadImages:
    """Exercise `load_images` against small files written under `tmp_path`."""

    def _write_png(self, path: Path, color: tuple[int, int, int]) -> None:
        """Write a 2x2 solid-colour RGB PNG to *path*."""
        Image.new("RGB", (2, 2), color=color).save(path, format="PNG")

    def test_loads_single_image(self, tmp_path: Path) -> None:
        """A single path yields exactly one PIL image in RGB mode."""
        p = tmp_path / "a.png"
        self._write_png(p, (255, 0, 0))
        images = load_images([p])
        assert len(images) == 1
        assert isinstance(images[0], Image.Image)
        assert images[0].mode == "RGB"

    def test_loads_multiple_images_preserves_order(self, tmp_path: Path) -> None:
        """Results come back in the same order as the input paths."""
        a = tmp_path / "a.png"
        b = tmp_path / "b.png"
        # Distinct colours let each result be traced back to its source file.
        self._write_png(a, (255, 0, 0))
        self._write_png(b, (0, 255, 0))
        [first, second] = load_images([a, b])
        assert first.getpixel((0, 0)) == (255, 0, 0)
        assert second.getpixel((0, 0)) == (0, 255, 0)

    def test_missing_file_raises_clearly(self, tmp_path: Path) -> None:
        """A nonexistent path raises FileNotFoundError mentioning the image."""
        with pytest.raises(FileNotFoundError, match="image not found"):
            load_images([tmp_path / "nope.png"])

    def test_non_image_raises_pil_error(self, tmp_path: Path) -> None:
        """A file that is not decodable as an image raises PIL's own error."""
        # Local import keeps this test self-contained; Pillow is already a
        # dependency of this module via `from PIL import Image`.
        from PIL import UnidentifiedImageError

        bogus = tmp_path / "x.png"
        bogus.write_text("not an image", encoding="utf-8")
        # Expect the specific PIL error rather than bare `Exception`, so an
        # unrelated failure (TypeError, AttributeError, ...) cannot make this
        # test pass by accident.
        # NOTE(review): assumes load_images lets PIL's error propagate
        # unwrapped — confirm against the implementation.
        with pytest.raises(UnidentifiedImageError):
            load_images([bogus])

    def test_converts_to_rgb_from_other_modes(self, tmp_path: Path) -> None:
        """Non-RGB input (here RGBA) is still returned in RGB mode."""
        # Save a 4-channel RGBA image; loader must still emit RGB.
        p = tmp_path / "rgba.png"
        Image.new("RGBA", (2, 2), color=(255, 0, 0, 128)).save(p, format="PNG")
        [image] = load_images([p])
        assert image.mode == "RGB"