1 """Vision preprocessing: blob bytes → tensor via HF AutoProcessor.
2
3 Thin wrapper that runs a pre-loaded HF processor over a PIL image
4 loaded from the content-addressed blob store, with on-disk caching
5 keyed on `(blob_sha, processor_sha, target_size)`.
6
7 Callers own the processor lifecycle — `AutoProcessor.from_pretrained`
8 is expensive, so loading it once at trainer startup and reusing
9 across sections is the expected pattern. The cache does the heavy
10 lifting for repeat runs on the same corpus.
11
12 Heavy imports (`PIL`, `numpy`) happen inside the functions that
13 need them; the module is cheap to import for CLI subcommands that
14 don't touch images.
15 """
16
17 from __future__ import annotations
18
19 from dataclasses import dataclass
20 from pathlib import Path
21 from typing import Any, Final
22
23 import numpy as np
24
25 from dlm.data.vl_cache import VlCache, VlCacheKey, processor_sha256
26
27
@dataclass(frozen=True)
class PreprocessedImage:
    """Result of running a processor over a single image.

    `pixel_values` is the processor's pixel tensor shaped
    `(num_patches, channels, height, width)` for most VL bases; some
    (Qwen2-VL) emit variable patch counts per image. `cache_hit`
    records whether the value came from disk so callers can surface
    hit rates.
    """

    # Float32 array (coerced in `_run_processor` via `astype(np.float32)`);
    # exact shape depends on the processor — see class docstring.
    pixel_values: np.ndarray
    # True iff the array was served from the on-disk cache rather than
    # computed by the processor on this call.
    cache_hit: bool
41
42
# Indirection over `VlCacheKey`; presumably a monkeypatch seam for tests —
# NOTE(review): confirm; otherwise `VlCacheKey` could be called directly.
# `Final` documents that the binding is never reassigned.
_CACHE_KEY_FACTORY: Final = VlCacheKey
44
45
def preprocess_image(
    *,
    blob_path: Path,
    blob_sha: str,
    processor: Any,
    target_size: tuple[int, int],
    cache: VlCache | None = None,
) -> PreprocessedImage:
    """Preprocess a single image blob into a pixel-values tensor.

    `processor` is a pre-loaded HF processor (`AutoProcessor.from_pretrained`).
    `target_size` is the pinned `(height, width)` from the base's
    `VlPreprocessorPlan` — part of the cache key.

    On cache hit, returns the cached array without touching the
    processor. On miss, runs the processor and writes the result
    back through the cache. `cache=None` bypasses caching entirely
    (tests, ad-hoc prompts): no cache key — and thus no processor
    hash — is computed in that case.
    """
    if cache is None:
        # Uncached path: skip key construction entirely. The original
        # code hashed the processor even when nothing was keyed on it,
        # which is wasted work (and contradicted "bypasses entirely").
        tensor = _run_processor(processor, blob_path)
        return PreprocessedImage(pixel_values=tensor, cache_hit=False)

    key = _CACHE_KEY_FACTORY(
        blob_sha=blob_sha,
        processor_sha=processor_sha256(processor),
        target_height=target_size[0],
        target_width=target_size[1],
    )

    hit = cache.get(key)
    if hit is not None:
        return PreprocessedImage(pixel_values=hit, cache_hit=True)

    # Miss: run the processor and write back through the cache.
    tensor = _run_processor(processor, blob_path)
    cache.put(key, tensor)
    return PreprocessedImage(pixel_values=tensor, cache_hit=False)
84
85
def _run_processor(processor: Any, blob_path: Path) -> np.ndarray:
    """Run the HF processor over one image file, return `pixel_values`.

    PIL is imported lazily so the module stays cheap to import. The
    image is fully decoded (`load()`) and converted to RGB inside the
    `with` block, so the underlying file handle is released *before*
    the processor runs — handles don't pile up on large corpora.

    Always returns float32: `return_tensors="np"` is requested, but
    some processor versions hand back a torch tensor regardless, so
    the result is coerced defensively.
    """
    from PIL import Image

    with Image.open(blob_path) as src:
        # Force the full decode while the file is still open.
        src.load()
        rgb_image = src.convert("RGB")

    batch = processor(images=rgb_image, return_tensors="np")
    values = batch["pixel_values"]
    if isinstance(values, np.ndarray):
        return values.astype(np.float32, copy=False)
    # Defensive: processor ignored return_tensors and wrapped the
    # output (e.g. as a torch tensor) anyway — coerce to numpy.
    return np.asarray(values, dtype=np.float32)