tenseleyflow/documentlanguagemodel / feb8231


feat(inference): VL generate + loader (AutoModelForImageTextToText + processor)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA      feb8231326eecac42120b300e7c5c5690c353dc9
Parents  303235a
Tree     e18b618

2 changed files

Status  File                              +    -
A       src/dlm/inference/vl_generate.py  121  0
A       src/dlm/inference/vl_loader.py    98   0
src/dlm/inference/vl_generate.py (added)
@@ -0,0 +1,121 @@
+"""Vision-language generation path for `dlm prompt --image`.
+
+Mirrors `dlm.inference.generate` but drives an HF `AutoProcessor`
+(not a bare tokenizer) + `AutoModelForImageTextToText` through a
+prompt that carries one or more image placeholders.
+
+Shape contract matches what TRL 1.2's
+`DataCollatorForVisionLanguageModeling` emits at training time: the
+user's text carries the base's `image_token` placeholder (e.g.
+`<image>`) and the processor expands each occurrence into the base's
+`num_image_tokens` slots. This keeps prompt-time input aligned with
+training-time input — the same lesson the text path learned with
+`format_chat_prompt`.
+
+Heavy imports (`PIL`, `torch`) are deferred inside the functions so
+importing this module stays cheap.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from dlm.inference.generate import DEFAULT_MAX_NEW_TOKENS, build_generate_kwargs
+
+
+def format_vl_prompt(
+    prompt: str,
+    *,
+    image_token: str,
+    num_images: int,
+) -> str:
+    """Build the VL-aware prompt text.
+
+    When the user's prompt already contains `image_token`, pass it
+    through — they explicitly placed the image. Otherwise prepend one
+    `image_token` per image so the processor can slot the pixels in
+    before the text; a trailing newline separates the image block from
+    the user's question the way every VL chat template does.
+
+    This matches `sections_to_rows`' IMAGE emission at training time:
+    `"<image>\\n<caption>"` — training and prompt-time input see the
+    same token order.
+    """
+    if image_token in prompt:
+        return prompt
+    tokens = image_token * num_images
+    return f"{tokens}\n{prompt}" if prompt else tokens
+
+
+def load_images(paths: list[Path]) -> list[Any]:
+    """Open each path as a PIL.Image in RGB mode.
+
+    Raises `FileNotFoundError` on missing paths; a PIL `UnidentifiedImageError`
+    on files that aren't a decodable image. Both bubble up to the CLI,
+    which converts them into typer exits.
+    """
+    from PIL import Image
+
+    images: list[Any] = []
+    for path in paths:
+        if not path.exists():
+            raise FileNotFoundError(f"image not found: {path}")
+        with Image.open(path) as pil:
+            pil.load()
+            images.append(pil.convert("RGB"))
+    return images
+
+
+def generate_vl(  # pragma: no cover
+    model: Any,
+    processor: Any,
+    prompt: str,
+    images: list[Any],
+    *,
+    image_token: str,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.0,
+    top_p: float | None = None,
+    top_k: int | None = None,
+    repetition_penalty: float | None = None,
+) -> str:
+    """Render VL prompt, run generation, decode response-only tokens.
+
+    `processor` is an `AutoProcessor` for a VL base. `images` is a
+    list of PIL.Image objects, one per `image_token` occurrence in
+    `prompt` (or prepended for the user). `image_token` comes from
+    the base's `VlPreprocessorPlan`.
+
+    Pragma'd from unit coverage because it calls `model.generate` on a
+    real HF VL model; covered by the slow-marked integration test.
+    """
+    import torch
+
+    formatted = format_vl_prompt(prompt, image_token=image_token, num_images=len(images))
+    inputs = processor(
+        images=images,
+        text=formatted,
+        return_tensors="pt",
+    ).to(model.device)
+    input_len = int(inputs["input_ids"].shape[-1])
+
+    gen_kwargs = build_generate_kwargs(
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+    )
+
+    tokenizer = getattr(processor, "tokenizer", processor)
+    with torch.inference_mode():
+        output = model.generate(
+            **inputs,
+            **gen_kwargs,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+
+    response_tokens = output[0, input_len:]
+    decoded = tokenizer.decode(response_tokens, skip_special_tokens=True)
+    return str(decoded)
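
For reference, the two branches of `format_vl_prompt` behave as follows. This is a minimal sketch; the `<image>` placeholder string is only an example here, the real value comes from the base's `VlPreprocessorPlan`.

from dlm.inference.vl_generate import format_vl_prompt

# No placeholder in the user's prompt: one image token per image is
# prepended, and a newline separates the image block from the question.
assert format_vl_prompt(
    "What is shown here?", image_token="<image>", num_images=2
) == "<image><image>\nWhat is shown here?"

# The user already placed the placeholder: the prompt passes through untouched.
assert format_vl_prompt(
    "<image> Describe the chart.", image_token="<image>", num_images=1
) == "<image> Describe the chart."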
src/dlm/inference/vl_loader.py (added)
@@ -0,0 +1,98 @@
+"""VL inference loader (base + processor + adapter) for `dlm prompt --image`.
+
+Parallel to `dlm.inference.loader` — the text path loads
+`AutoModelForCausalLM` + a tokenizer; this path loads
+`AutoModelForImageTextToText` + the full `AutoProcessor` (ProcessorMixin).
+QLoRA is not plumbed through the VL path in v1: PaliGemma fp16 fits
+on 16 GB MPS, and the bitsandbytes + VL weight loading combination
+isn't exercised anywhere in our test matrix yet — Sprint 35.3 or a
+dedicated audit can thread it when the need surfaces.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from dlm.inference.loader import resolve_adapter_path
+from dlm.inference.plan import InferencePlan
+
+if TYPE_CHECKING:
+    from dlm.base_models import BaseModelSpec
+    from dlm.store.paths import StorePath
+
+
+@dataclass(frozen=True)
+class LoadedVlInference:
+    """Result of `load_for_vl_inference`."""
+
+    model: Any
+    processor: Any
+    plan: InferencePlan
+    adapter_path: Path
+
+
+def load_for_vl_inference(  # pragma: no cover
+    store: StorePath,
+    spec: BaseModelSpec,
+    caps: Any,
+    *,
+    adapter_name: str | None = None,
+) -> LoadedVlInference:
+    """Resolve plan + load VL base + adapter + processor.
+
+    Pragma'd from unit coverage: exercises `AutoModelForImageTextToText.from_pretrained`
+    and `AutoProcessor.from_pretrained` over real HF weights. Covered
+    by the Sprint 35 v1 slow integration test (T12).
+    """
+    if spec.modality != "vision-language":
+        raise ValueError(
+            f"load_for_vl_inference: {spec.key!r} is modality={spec.modality!r}; "
+            "use load_for_inference for text bases"
+        )
+
+    adapter_path = resolve_adapter_path(store, adapter_name=adapter_name)
+
+    from transformers import AutoModelForImageTextToText, AutoProcessor
+
+    from dlm.inference.plan import resolve_inference
+
+    plan = resolve_inference(adapter_path, caps)
+    dtype = _torch_dtype_for(plan.precision)
+
+    base = AutoModelForImageTextToText.from_pretrained(
+        spec.hf_id,
+        revision=spec.revision,
+        torch_dtype=dtype,
+        attn_implementation=plan.attn_implementation,
+    )
+
+    from peft import PeftModel
+
+    model = PeftModel.from_pretrained(base, str(adapter_path))
+    model.eval()
+
+    # Processor comes from the pinned base (not the adapter dir) because
+    # VL adapters don't snapshot the processor — pixel-path config is
+    # deterministic per base revision.
+    processor = AutoProcessor.from_pretrained(spec.hf_id, revision=spec.revision)
+
+    return LoadedVlInference(
+        model=model,
+        processor=processor,
+        plan=plan,
+        adapter_path=adapter_path,
+    )
+
+
+def _torch_dtype_for(precision: str) -> Any:  # pragma: no cover
+    try:
+        import torch
+    except ImportError:
+        return precision
+    lookup = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16,
+    }
+    return lookup.get(precision, torch.float16)
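
Taken together, the loader and the generate path are meant to compose roughly as below. This is a hedged sketch, not code from the commit: `store`, `spec`, and `caps` come from the existing text-path plumbing, and the `<image>` token is hard-coded only because the `VlPreprocessorPlan` accessor is not part of this diff.

from pathlib import Path

from dlm.inference.vl_generate import generate_vl, load_images
from dlm.inference.vl_loader import load_for_vl_inference

# store, spec, and caps are resolved the same way the text path resolves them.
loaded = load_for_vl_inference(store, spec, caps, adapter_name=None)

# One PIL image per image placeholder in the prompt.
images = load_images([Path("chart.png")])

answer = generate_vl(
    loaded.model,
    loaded.processor,
    "What trend does this chart show?",
    images,
    image_token="<image>",  # assumption: real value comes from the base's VlPreprocessorPlan
    max_new_tokens=128,
)
print(answer)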