Python · 3108 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.suite.spec` + :mod:`dlm_sway.suite.loader`."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6
7 import pytest
8
9 from dlm_sway.core.errors import SpecValidationError
10 from dlm_sway.suite.loader import from_dict, load_spec
11 from dlm_sway.suite.spec import SwaySpec
12
13
14 def _minimum_valid() -> dict:
15 return {
16 "version": 1,
17 "models": {
18 "base": {"kind": "hf", "base": "HuggingFaceTB/SmolLM2-135M-Instruct"},
19 "ft": {
20 "kind": "hf",
21 "base": "HuggingFaceTB/SmolLM2-135M-Instruct",
22 "adapter": "/tmp/adapter",
23 },
24 },
25 "suite": [],
26 }
27
28
class TestSwaySpec:
    """Validation behavior of specs built through :func:`from_dict`."""

    def test_minimum_valid(self) -> None:
        """The minimal spec parses into a ``SwaySpec`` with expected defaults."""
        spec = from_dict(_minimum_valid())
        assert isinstance(spec, SwaySpec)
        assert spec.version == 1
        assert spec.defaults.seed == 0
        assert spec.defaults.differential is True
        assert spec.suite == []

    def test_rejects_unknown_top_level_keys(self) -> None:
        """An unrecognized top-level key is rejected and named in the error."""
        data = _minimum_valid()
        data["bogus"] = True
        with pytest.raises(SpecValidationError) as exc_info:
            from_dict(data)
        # Compare case-insensitively so the check does not depend on how the
        # error message capitalizes the offending key.
        assert "bogus" in str(exc_info.value).lower()

    def test_rejects_future_version(self) -> None:
        """A spec version the loader does not support is refused."""
        data = _minimum_valid()
        data["version"] = 9
        with pytest.raises(SpecValidationError, match="unsupported sway spec version"):
            from_dict(data)

    def test_defaults_frozen(self) -> None:
        """``spec.defaults`` rejects attribute assignment and keeps its value."""
        from pydantic import ValidationError

        spec = from_dict(_minimum_valid())
        with pytest.raises(ValidationError):
            spec.defaults.seed = 99  # type: ignore[misc]
        # A rejected assignment must not have partially applied.
        assert spec.defaults.seed == 0
57
58
class TestLoader:
    """File-level error handling and YAML round-tripping in :func:`load_spec`."""

    def test_missing_file(self, tmp_path: Path) -> None:
        """A path that does not exist surfaces as a ``SpecValidationError``."""
        missing = tmp_path / "nope.yaml"
        with pytest.raises(SpecValidationError, match="not found"):
            load_spec(missing)

    def test_invalid_yaml(self, tmp_path: Path) -> None:
        """Syntactically broken YAML surfaces as a ``SpecValidationError``."""
        bad = tmp_path / "bad.yaml"
        # Unclosed flow collections make PyYAML reach end-of-stream while a
        # flow node is still expected, which is a hard parse error. Odd
        # indentation is a poor choice for this test: it can parse as a plain
        # string value, which is valid YAML.
        bad.write_text("{ unmatched: [", encoding="utf-8")
        with pytest.raises(SpecValidationError, match="invalid YAML"):
            load_spec(bad)

    def test_non_mapping_top_level(self, tmp_path: Path) -> None:
        """A YAML document whose top level is a list is rejected."""
        bad = tmp_path / "list.yaml"
        bad.write_text("- 1\n- 2\n", encoding="utf-8")
        with pytest.raises(SpecValidationError, match="must be a mapping"):
            load_spec(bad)

    def test_roundtrip_via_yaml(self, tmp_path: Path) -> None:
        """A spec dumped with ``yaml.safe_dump`` loads back via ``load_spec``."""
        import yaml

        path = tmp_path / "sway.yaml"
        path.write_text(yaml.safe_dump(_minimum_valid()), encoding="utf-8")
        spec = load_spec(path)
        assert spec.version == 1
        assert spec.suite == []
        # B22: ModelSpec.adapter is normalized via Path.resolve(), so
        # symlinked roots (/tmp → /private/tmp on macOS) get followed.
        # Compare against the resolved form so the test isn't host-dependent.
        assert spec.models.ft.adapter == Path("/tmp/adapter").resolve()
88 assert spec.models.ft.adapter == Path("/tmp/adapter").resolve()