Python · 6593 bytes Raw Blame History
1 """Directive walker image-extension dispatch (Sprint 35 v1).
2
3 The walker sees a `.png` / `.jpg` / `.webp` file, hands its bytes to
4 the `BlobStore`, and synthesizes a `Section(type=IMAGE, ...)` with
5 `media_path` = relpath and `media_blob_sha` = the hash. Without a
6 `blob_store` argument, image files are tallied under
7 `skipped_image_no_store` so `dlm show` can surface "would ingest N
8 images on next train" without touching disk.
9 """
10
11 from __future__ import annotations
12
13 import hashlib
14 from pathlib import Path
15
16 from dlm.directives.expand import expand_sources
17 from dlm.doc.parser import parse_text
18 from dlm.doc.sections import SectionType
19 from dlm.store.blobs import BlobStore
20
21
22 def _dlm(body: str = "") -> str:
23 return (
24 "---\n"
25 "dlm_id: 01KPMGSTNGSTTSTTSTTSTTSTVA\n"
26 "dlm_version: 10\n"
27 "base_model: smollm2-135m\n"
28 "training:\n"
29 " sources:\n"
30 f"{body}"
31 "---\n"
32 )
33
34
def _parse(body: str) -> object:
    """Wrap `body` with `_dlm` and return whatever `parse_text` produces for it."""
    document_text = _dlm(body)
    return parse_text(document_text)
37
38
class TestImageExtensionDispatch:
    """`expand_sources` routing of files with image extensions.

    Every test writes files under `tmp_path / "corpus"`, declares one source
    directive pointing at that directory, and inspects the sections and
    provenance returned by `expand_sources`.

    In each directive string the `include` key is aligned with `path` (deeper
    than the `-` dash) so YAML reads it as a key of that source entry rather
    than as a sibling of `sources`.
    """

    def test_png_ingested_as_image(self, tmp_path: Path) -> None:
        """A `.png` yields one IMAGE section with relpath, stem alt and SHA-256."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        payload = b"\x89PNG\r\n\x1a\nbody"
        (corpus / "arch.png").write_bytes(payload)
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.png"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)

        assert len(result.sections) == 1
        section = result.sections[0]
        assert section.type == SectionType.IMAGE
        assert section.media_path == "arch.png"
        assert section.media_alt == "arch"
        # The blob SHA is expected to be the SHA-256 hex digest of the raw bytes.
        expected_sha = hashlib.sha256(payload).hexdigest()
        assert section.media_blob_sha == expected_sha

    def test_missing_blob_store_counts_skip(self, tmp_path: Path) -> None:
        """With `blob_store=None` the image is tallied as skipped, not ingested."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        (corpus / "hero.jpg").write_bytes(b"jpeg body")
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.jpg"]\n')
        result = expand_sources(parsed, base_path=tmp_path, blob_store=None)
        assert result.sections == ()
        # Exactly one provenance record is expected for the single directive.
        [prov] = result.provenance
        assert prov.skipped_image_no_store == 1
        assert prov.image_count == 0

    def test_text_and_image_mix_in_one_directive(self, tmp_path: Path) -> None:
        """One directive matching a PNG and a Markdown file yields both kinds."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        (corpus / "hero.png").write_bytes(b"png bytes")
        (corpus / "notes.md").write_text("Notes.\n", encoding="utf-8")
        parsed = _parse(
            f'    - path: {corpus}\n      include: ["**/*.png", "**/*.md"]\n'
        )
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        kinds = [s.type for s in result.sections]
        assert SectionType.IMAGE in kinds
        assert SectionType.PROSE in kinds
        [prov] = result.provenance
        assert prov.image_count == 1
        assert prov.file_count == 1  # prose count only

    def test_identical_bytes_different_paths_distinct_section_ids(self, tmp_path: Path) -> None:
        """Two files with the same bytes share a blob SHA but not a section id."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        body = b"same-bytes-shared"
        (corpus / "a.png").write_bytes(body)
        (corpus / "b.png").write_bytes(body)
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.png"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        assert len(result.sections) == 2
        a, b = result.sections
        assert a.media_blob_sha == b.media_blob_sha
        assert a.section_id != b.section_id

    def test_image_file_bypasses_binary_skip(self, tmp_path: Path) -> None:
        """A JPEG payload containing NUL bytes is still ingested as an image."""
        # A PNG starts with a NUL-free signature but most JPEGs contain
        # NUL in the first KiB. The text-read path's binary heuristic
        # would skip that; the image dispatch must take precedence.
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        payload = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00rest"
        (corpus / "photo.jpg").write_bytes(payload)
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.jpg"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        assert len(result.sections) == 1
        assert result.sections[0].type == SectionType.IMAGE

    def test_extension_is_case_insensitive(self, tmp_path: Path) -> None:
        """An upper-case `.PNG` suffix is dispatched as an image as well."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        (corpus / "FIG.PNG").write_bytes(b"uppercase extension")
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.PNG"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        assert len(result.sections) == 1
        assert result.sections[0].type == SectionType.IMAGE
119
120
class TestImageAltDefaults:
    """Default `media_alt` value of image sections produced by `expand_sources`."""

    def test_alt_defaults_to_stem(self, tmp_path: Path) -> None:
        """The section's `media_alt` equals the file name without its suffix."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        (corpus / "pipeline-v2.png").write_bytes(b"bytes")
        # `include` is aligned with `path` (deeper than the dash) so YAML reads
        # it as a key of this source entry rather than a sibling of `sources`.
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.png"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        assert result.sections[0].media_alt == "pipeline-v2"
130
131
class TestImageProvenance:
    """Provenance counters and the `max_files` cap for image sources.

    In each directive string the `include` / `max_files` keys are aligned with
    `path` (deeper than the `-` dash) so YAML reads them as keys of that source
    entry rather than as siblings of `sources`.
    """

    def test_image_count_and_bytes(self, tmp_path: Path) -> None:
        """Two PNGs of 100 and 200 bytes report `image_count` 2 and 300 bytes."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        (corpus / "a.png").write_bytes(b"a" * 100)
        (corpus / "b.png").write_bytes(b"b" * 200)
        parsed = _parse(f'    - path: {corpus}\n      include: ["**/*.png"]\n')
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        # Exactly one provenance record is expected for the single directive.
        [prov] = result.provenance
        assert prov.image_count == 2
        assert prov.image_bytes == 300

    def test_max_files_cap_includes_images(self, tmp_path: Path) -> None:
        """With five PNGs on disk and `max_files: 3`, only three sections result."""
        corpus = tmp_path / "corpus"
        corpus.mkdir()
        for i in range(5):
            (corpus / f"{i}.png").write_bytes(f"payload {i}".encode())
        parsed = _parse(
            f'    - path: {corpus}\n      include: ["**/*.png"]\n      max_files: 3\n'
        )
        blob_store = BlobStore(tmp_path / "blobs")
        result = expand_sources(parsed, base_path=tmp_path, blob_store=blob_store)
        assert len(result.sections) == 3
153 assert len(result.sections) == 3