tenseleyflow/documentlanguagemodel / 62c3366

Browse files

Cover base model edge branches

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
62c33663c932d9fc185d8450258add0067c8bf56
Parents
b73a078
Tree
8b2d8c1

5 changed files

Status File + -
M tests/unit/base_models/test_downloader.py 22 0
M tests/unit/base_models/test_license.py 25 0
M tests/unit/base_models/test_probes.py 216 0
M tests/unit/base_models/test_registry_refresh.py 119 1
M tests/unit/base_models/test_schema.py 1 1
tests/unit/base_models/test_downloader.py modified
@@ -8,6 +8,7 @@ from unittest.mock import patch
88
 import pytest
99
 
1010
 from dlm.base_models import BaseModelSpec, GatedModelError, download_spec, sha256_of_directory
11
+from dlm.base_models.downloader import _resolve_revision
1112
 
1213
 
1314
 def _spec() -> BaseModelSpec:
@@ -139,3 +140,24 @@ class TestDownloadSpec:
139140
             pytest.raises(RuntimeError, match="offline"),
140141
         ):
141142
             download_spec(_spec(), local_files_only=True)
143
+
144
+    def test_repository_not_found_raises_runtime_error(self) -> None:
145
+        from unittest.mock import Mock
146
+
147
+        from huggingface_hub.errors import RepositoryNotFoundError
148
+
149
+        with (
150
+            patch(
151
+                "huggingface_hub.snapshot_download",
152
+                side_effect=RepositoryNotFoundError("missing", response=Mock()),
153
+            ),
154
+            pytest.raises(RuntimeError, match="HF repository not found"),
155
+        ):
156
+            download_spec(_spec())
157
+
158
+    def test_resolve_revision_falls_back_to_expected_outside_snapshot_layout(
159
+        self, tmp_path: Path
160
+    ) -> None:
161
+        local_copy = tmp_path / "local-model"
162
+        local_copy.mkdir()
163
+        assert _resolve_revision(local_copy, "a" * 40) == "a" * 40
tests/unit/base_models/test_license.py modified
@@ -8,6 +8,7 @@ import pytest
88
 
99
 from dlm.base_models import BASE_MODELS, GatedModelError, LicenseAcceptance
1010
 from dlm.base_models.license import is_gated, require_acceptance
11
+from dlm.base_models.schema import BaseModelSpec
1112
 
1213
 
1314
 def _non_gated_spec() -> object:
@@ -81,6 +82,30 @@ class TestRequireAcceptance:
8182
         assert acc is not None
8283
         assert acc.via == via
8384
 
85
+    def test_gated_spec_without_license_url_fails_loudly(self) -> None:
86
+        spec = BaseModelSpec.model_validate(
87
+            {
88
+                "key": "broken-gated",
89
+                "hf_id": "org/broken",
90
+                "revision": "0" * 40,
91
+                "architecture": "DemoForCausalLM",
92
+                "params": 1_000_000_000,
93
+                "target_modules": ["q_proj"],
94
+                "template": "chatml",
95
+                "gguf_arch": "demo",
96
+                "tokenizer_pre": "demo",
97
+                "license_spdx": "Other",
98
+                "license_url": None,
99
+                "requires_acceptance": True,
100
+                "redistributable": False,
101
+                "size_gb_fp16": 2.0,
102
+                "context_length": 4096,
103
+                "recommended_seq_len": 2048,
104
+            }
105
+        )
106
+        with pytest.raises(GatedModelError):
107
+            require_acceptance(spec, accept_license=True, via="cli_flag")
108
+
84109
 
85110
 class TestAcceptanceModel:
86111
     def test_frozen_rejects_mutation(self) -> None:
tests/unit/base_models/test_probes.py modified
@@ -13,10 +13,12 @@ from dlm.base_models import BaseModelSpec, GatedModelError
1313
 from dlm.base_models.probes import (
1414
     _LLAMA_CPP_CHKTXT,
1515
     probe_architecture,
16
+    probe_audio_token,
1617
     probe_chat_template,
1718
     probe_gguf_arch_supported,
1819
     probe_pretokenizer_hash,
1920
     probe_pretokenizer_label,
21
+    probe_vl_image_token,
2022
     run_all,
2123
 )
2224
 
@@ -43,6 +45,37 @@ def _spec() -> BaseModelSpec:
4345
     )
4446
 
4547
 
48
+def _vl_spec() -> BaseModelSpec:
49
+    return BaseModelSpec.model_validate(
50
+        {
51
+            **_spec().model_dump(),
52
+            "key": "demo-vl",
53
+            "modality": "vision-language",
54
+            "vl_preprocessor_plan": {
55
+                "target_size": [224, 224],
56
+                "image_token": "<image>",
57
+                "num_image_tokens": 256,
58
+            },
59
+        }
60
+    )
61
+
62
+
63
+def _audio_spec() -> BaseModelSpec:
64
+    return BaseModelSpec.model_validate(
65
+        {
66
+            **_spec().model_dump(),
67
+            "key": "demo-audio",
68
+            "modality": "audio-language",
69
+            "audio_preprocessor_plan": {
70
+                "sample_rate": 16000,
71
+                "audio_token": "<audio>",
72
+                "num_audio_tokens": 64,
73
+                "max_length_seconds": 30.0,
74
+            },
75
+        }
76
+    )
77
+
78
+
4679
 class TestProbeArchitecture:
4780
     def test_matching_architectures_pass(self) -> None:
4881
         fake_cfg = SimpleNamespace(architectures=["DemoForCausalLM"])
@@ -105,6 +138,20 @@ class TestProbeChatTemplate:
105138
             result = probe_chat_template(_spec())
106139
         assert result.passed is False
107140
 
141
+    def test_gated_repo_raises_gated_model_error(self) -> None:
142
+        from unittest.mock import Mock
143
+
144
+        from huggingface_hub.errors import GatedRepoError
145
+
146
+        with (
147
+            patch(
148
+                "transformers.AutoTokenizer.from_pretrained",
149
+                side_effect=GatedRepoError("gated", response=Mock()),
150
+            ),
151
+            pytest.raises(GatedModelError),
152
+        ):
153
+            probe_chat_template(_spec())
154
+
108155
 
109156
 class TestProbeGgufArch:
110157
     def test_skips_when_vendor_missing(self, tmp_path: Path) -> None:
@@ -182,6 +229,16 @@ class TestProbeGgufArch:
182229
         result = probe_gguf_arch_supported(_spec(), vendor_path=vendor)
183230
         assert result.passed is True
184231
 
232
+    def test_read_error_fails(self, tmp_path: Path) -> None:
233
+        vendor = tmp_path / "llama.cpp"
234
+        vendor.mkdir()
235
+        script = vendor / "convert_hf_to_gguf.py"
236
+        script.write_text('@Model.register("DemoForCausalLM")\n', encoding="utf-8")
237
+        with patch.object(Path, "read_text", side_effect=OSError("boom")):
238
+            result = probe_gguf_arch_supported(_spec(), vendor_path=vendor)
239
+        assert result.passed is False
240
+        assert "read failed" in result.detail
241
+
185242
 
186243
 class TestProbePretokenizerLabel:
187244
     def test_skips_when_table_missing(self, tmp_path: Path) -> None:
@@ -208,6 +265,14 @@ class TestProbePretokenizerLabel:
208265
         result = probe_pretokenizer_label(_spec(), hashes_path=hashes)
209266
         assert result.passed is False
210267
 
268
+    def test_wrong_shape_table_fails(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
269
+        hashes = tmp_path / "h.json"
270
+        hashes.write_text("[]", encoding="utf-8")
271
+        monkeypatch.setattr("dlm.base_models.probes.json.loads", lambda _text: [["nested"]])
272
+        result = probe_pretokenizer_label(_spec(), hashes_path=hashes)
273
+        assert result.passed is False
274
+        assert "wrong shape" in result.detail
275
+
211276
 
212277
 class TestProbePretokenizerHash:
213278
     """Audit-04 B8: real sha256-of-token-ids fingerprint check."""
@@ -325,6 +390,127 @@ class TestProbePretokenizerHash:
325390
         assert result.passed is False
326391
         assert "wrong shape" in result.detail
327392
 
393
+    def test_tokenizer_encode_failure_fails(self, tmp_path: Path) -> None:
394
+        table = tmp_path / "fp.json"
395
+        table.write_text(json.dumps({"demo": "a" * 64}), encoding="utf-8")
396
+
397
+        with patch(
398
+            "transformers.AutoTokenizer.from_pretrained",
399
+            return_value=SimpleNamespace(
400
+                encode=lambda _text: (_ for _ in ()).throw(ValueError("boom"))
401
+            ),
402
+        ):
403
+            result = probe_pretokenizer_hash(_spec(), fingerprints_path=table)
404
+
405
+        assert result.passed is False
406
+        assert "tokenizer.encode failed" in result.detail
407
+
408
+
409
+class TestProbeVlImageToken:
410
+    def test_non_vl_spec_skips(self) -> None:
411
+        result = probe_vl_image_token(_spec())
412
+        assert result.skipped is True
413
+
414
+    def test_gated_processor_raises(self) -> None:
415
+        from unittest.mock import Mock
416
+
417
+        from huggingface_hub.errors import GatedRepoError
418
+
419
+        with (
420
+            patch(
421
+                "dlm.base_models._typed_shims.load_auto_processor",
422
+                side_effect=GatedRepoError("gated", response=Mock()),
423
+            ),
424
+            pytest.raises(GatedModelError),
425
+        ):
426
+            probe_vl_image_token(_vl_spec())
427
+
428
+    def test_missing_tokenizer_fails(self) -> None:
429
+        with patch(
430
+            "dlm.base_models._typed_shims.load_auto_processor",
431
+            return_value=SimpleNamespace(),
432
+        ):
433
+            result = probe_vl_image_token(_vl_spec())
434
+        assert result.passed is False
435
+        assert "no `.tokenizer`" in result.detail
436
+
437
+    def test_tokenizer_encode_error_fails(self) -> None:
438
+        tokenizer = SimpleNamespace(
439
+            encode=lambda _placeholder, add_special_tokens=False: (_ for _ in ()).throw(
440
+                ValueError("boom")
441
+            )
442
+        )
443
+        with patch(
444
+            "dlm.base_models._typed_shims.load_auto_processor",
445
+            return_value=SimpleNamespace(tokenizer=tokenizer),
446
+        ):
447
+            result = probe_vl_image_token(_vl_spec())
448
+        assert result.passed is False
449
+        assert "tokenizer rejected placeholder" in result.detail
450
+
451
+    def test_multi_token_placeholder_fails(self) -> None:
452
+        tokenizer = SimpleNamespace(encode=lambda _placeholder, add_special_tokens=False: [1, 2])
453
+        with patch(
454
+            "dlm.base_models._typed_shims.load_auto_processor",
455
+            return_value=SimpleNamespace(tokenizer=tokenizer),
456
+        ):
457
+            result = probe_vl_image_token(_vl_spec())
458
+        assert result.passed is False
459
+        assert "expected 1" in result.detail
460
+
461
+
462
+class TestProbeAudioToken:
463
+    def test_non_audio_spec_skips(self) -> None:
464
+        result = probe_audio_token(_spec())
465
+        assert result.skipped is True
466
+
467
+    def test_gated_processor_raises(self) -> None:
468
+        from unittest.mock import Mock
469
+
470
+        from huggingface_hub.errors import GatedRepoError
471
+
472
+        with (
473
+            patch(
474
+                "dlm.base_models._typed_shims.load_auto_processor",
475
+                side_effect=GatedRepoError("gated", response=Mock()),
476
+            ),
477
+            pytest.raises(GatedModelError),
478
+        ):
479
+            probe_audio_token(_audio_spec())
480
+
481
+    def test_missing_tokenizer_fails(self) -> None:
482
+        with patch(
483
+            "dlm.base_models._typed_shims.load_auto_processor",
484
+            return_value=SimpleNamespace(),
485
+        ):
486
+            result = probe_audio_token(_audio_spec())
487
+        assert result.passed is False
488
+        assert "no `.tokenizer`" in result.detail
489
+
490
+    def test_tokenizer_encode_error_fails(self) -> None:
491
+        tokenizer = SimpleNamespace(
492
+            encode=lambda _placeholder, add_special_tokens=False: (_ for _ in ()).throw(
493
+                ValueError("boom")
494
+            )
495
+        )
496
+        with patch(
497
+            "dlm.base_models._typed_shims.load_auto_processor",
498
+            return_value=SimpleNamespace(tokenizer=tokenizer),
499
+        ):
500
+            result = probe_audio_token(_audio_spec())
501
+        assert result.passed is False
502
+        assert "tokenizer rejected placeholder" in result.detail
503
+
504
+    def test_multi_token_placeholder_fails(self) -> None:
505
+        tokenizer = SimpleNamespace(encode=lambda _placeholder, add_special_tokens=False: [1, 2])
506
+        with patch(
507
+            "dlm.base_models._typed_shims.load_auto_processor",
508
+            return_value=SimpleNamespace(tokenizer=tokenizer),
509
+        ):
510
+            result = probe_audio_token(_audio_spec())
511
+        assert result.passed is False
512
+        assert "expected 1" in result.detail
513
+
328514
 
329515
 class TestRunAll:
330516
     def test_aggregates_all_five_probes(self) -> None:
@@ -344,3 +530,33 @@ class TestRunAll:
344530
             "pretokenizer_label",
345531
             "pretokenizer_hash",
346532
         }
533
+
534
+    def test_vl_run_all_uses_vl_probe_and_skips_export_checks(self) -> None:
535
+        spec = _vl_spec()
536
+        fake_cfg = SimpleNamespace(architectures=["DemoForCausalLM"])
537
+        tokenizer = SimpleNamespace(encode=lambda _placeholder, add_special_tokens=False: [7])
538
+        with (
539
+            patch("transformers.AutoConfig.from_pretrained", return_value=fake_cfg),
540
+            patch(
541
+                "dlm.base_models._typed_shims.load_auto_processor",
542
+                return_value=SimpleNamespace(tokenizer=tokenizer),
543
+            ),
544
+        ):
545
+            report = run_all(spec)
546
+        names = {r.name for r in report.results}
547
+        assert names == {"architecture", "vl_image_token"}
548
+
549
+    def test_audio_run_all_uses_audio_probe_and_skips_export_checks(self) -> None:
550
+        spec = _audio_spec()
551
+        fake_cfg = SimpleNamespace(architectures=["DemoForCausalLM"])
552
+        tokenizer = SimpleNamespace(encode=lambda _placeholder, add_special_tokens=False: [9])
553
+        with (
554
+            patch("transformers.AutoConfig.from_pretrained", return_value=fake_cfg),
555
+            patch(
556
+                "dlm.base_models._typed_shims.load_auto_processor",
557
+                return_value=SimpleNamespace(tokenizer=tokenizer),
558
+            ),
559
+        ):
560
+            report = run_all(spec)
561
+        names = {r.name for r in report.results}
562
+        assert names == {"architecture", "audio_token"}
tests/unit/base_models/test_registry_refresh.py modified
@@ -3,8 +3,11 @@
33
 from __future__ import annotations
44
 
55
 from types import SimpleNamespace
6
+from urllib.error import HTTPError
67
 
7
-from dlm.base_models.registry_refresh import Drift, check_entry
8
+import pytest
9
+
10
+from dlm.base_models.registry_refresh import Drift, check_entry, check_registry, fetch_text
811
 from dlm.base_models.schema import BaseModelSpec
912
 
1013
 
@@ -37,6 +40,52 @@ class _Api:
3740
         return self._info
3841
 
3942
 
43
+class _RaisingApi:
44
+    def __init__(self, exc: Exception) -> None:
45
+        self._exc = exc
46
+
47
+    def model_info(self, _hf_id: str) -> SimpleNamespace:
48
+        raise self._exc
49
+
50
+
51
+class _FakeResponse:
52
+    def __init__(self, body: bytes, charset: str = "utf-8") -> None:
53
+        self._body = body
54
+        self.headers = SimpleNamespace(get_content_charset=lambda: charset)
55
+
56
+    def __enter__(self) -> _FakeResponse:
57
+        return self
58
+
59
+    def __exit__(self, exc_type: object, exc: object, tb: object) -> bool:
60
+        return False
61
+
62
+    def read(self) -> bytes:
63
+        return self._body
64
+
65
+
66
+class TestDrift:
67
+    def test_render_formats_each_field_on_its_own_line(self) -> None:
68
+        drift = Drift(
69
+            key="demo-1b",
70
+            hf_id="org/demo-1b",
71
+            fields=(("revision", "old", "new"), ("gating", "False", "True")),
72
+        )
73
+        assert drift.render() == (
74
+            "  demo-1b (org/demo-1b)\n"
75
+            "    revision               'old' → 'new'\n"
76
+            "    gating                 'False' → 'True'"
77
+        )
78
+
79
+
80
+class TestFetchText:
81
+    def test_fetch_text_decodes_response_body(self, monkeypatch: pytest.MonkeyPatch) -> None:
82
+        monkeypatch.setattr(
83
+            "dlm.base_models.registry_refresh.urlopen",
84
+            lambda req, timeout: _FakeResponse("olá".encode()),
85
+        )
86
+        assert fetch_text("https://example.com") == "olá"
87
+
88
+
4089
 class TestCheckEntry:
4190
     def test_no_drift_when_revision_and_gating_match(self) -> None:
4291
         spec = _spec()
@@ -80,3 +129,72 @@ class TestCheckEntry:
80129
             "official marker",
81130
             "missing from https://example.com/provenance",
82131
         ) in drift.fields
132
+
133
+    def test_gated_repo_is_reported_as_drift(self) -> None:
134
+        from unittest.mock import Mock
135
+
136
+        from huggingface_hub.errors import GatedRepoError
137
+
138
+        drift = check_entry(
139
+            _RaisingApi(GatedRepoError("gated", response=Mock())),
140
+            _spec(),
141
+        )
142
+        assert isinstance(drift, Drift)
143
+        assert ("gating", "readable", "now fully gated") in drift.fields
144
+
145
+    def test_missing_repository_is_reported_as_drift(self) -> None:
146
+        from unittest.mock import Mock
147
+
148
+        from huggingface_hub.errors import RepositoryNotFoundError
149
+
150
+        drift = check_entry(
151
+            _RaisingApi(RepositoryNotFoundError("missing", response=Mock())),
152
+            _spec(),
153
+        )
154
+        assert isinstance(drift, Drift)
155
+        assert ("repository", "present", "missing (renamed or deleted)") in drift.fields
156
+
157
+    def test_gating_mismatch_is_reported_when_enabled(self) -> None:
158
+        spec = _spec(requires_acceptance=False)
159
+        drift = check_entry(_Api(sha=spec.revision, gated=True), spec)
160
+        assert isinstance(drift, Drift)
161
+        assert ("requires_acceptance", "False", "True") in drift.fields
162
+
163
+    def test_unreachable_provenance_url_is_reported(self) -> None:
164
+        spec = _spec(
165
+            refresh_check_hf_gating=False,
166
+            provenance_url="https://example.com/provenance",
167
+            provenance_match_text="official marker",
168
+        )
169
+        drift = check_entry(
170
+            _Api(sha=spec.revision),
171
+            spec,
172
+            fetch_url_text=lambda _url: (_ for _ in ()).throw(
173
+                HTTPError(_url, 404, "missing", hdrs=None, fp=None)
174
+            ),
175
+        )
176
+        assert isinstance(drift, Drift)
177
+        assert drift.fields[0][0] == "provenance_url"
178
+        assert "unreachable" in drift.fields[0][2]
179
+
180
+
181
+class TestCheckRegistry:
182
+    def test_check_registry_collects_non_null_drifts(self, monkeypatch: pytest.MonkeyPatch) -> None:
183
+        entries = {
184
+            "one": _spec(key="one", hf_id="org/one"),
185
+            "two": _spec(key="two", hf_id="org/two"),
186
+        }
187
+        monkeypatch.setattr("dlm.base_models.registry_refresh.BASE_MODELS", entries)
188
+        monkeypatch.setattr("dlm.base_models.registry_refresh.HfApi", lambda: object())
189
+
190
+        def _fake_check_entry(
191
+            api: object, entry: BaseModelSpec, *, fetch_url_text: object
192
+        ) -> Drift | None:
193
+            if entry.key == "one":
194
+                return Drift(key="one", hf_id=entry.hf_id, fields=(("revision", "old", "new"),))
195
+            return None
196
+
197
+        monkeypatch.setattr("dlm.base_models.registry_refresh.check_entry", _fake_check_entry)
198
+        drifts = check_registry()
199
+        assert len(drifts) == 1
200
+        assert drifts[0].key == "one"
tests/unit/base_models/test_schema.py modified
@@ -55,7 +55,7 @@ class TestRevisionValidator:
5555
 
5656
 
5757
 class TestHfIdValidator:
58
-    @pytest.mark.parametrize("bad_id", ["", "no-slash", "trailing/"])
58
+    @pytest.mark.parametrize("bad_id", ["", "no-slash", "trailing/", "org/name/extra"])
5959
     def test_invalid_hf_id_rejected(self, bad_id: str) -> None:
6060
         # `trailing/` passes the `/ in value` gate but leading slash doesn't.
6161
         # Pydantic's min_length catches empty. Bad ones without `/` raise.