Python · 6685 bytes Raw Blame History
1 """Slash command parser + handler matrix."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6 from unittest.mock import MagicMock
7
8 import pytest
9
10 from dlm.repl.commands import Action, _truncate, is_command, parse_and_dispatch
11 from dlm.repl.errors import BadCommandArgumentError, UnknownCommandError
12 from dlm.repl.session import ReplSession
13
14
def _session(**overrides: object) -> ReplSession:
    """Construct a ``ReplSession`` backed by mock backend/tokenizer objects.

    Any keyword passed in replaces the matching default (``backend``,
    ``tokenizer``) or adds an extra constructor argument such as
    ``declared_adapters``.
    """
    kwargs: dict[str, object] = {
        "backend": MagicMock(name="backend"),
        "tokenizer": MagicMock(name="tokenizer"),
        **overrides,
    }
    return ReplSession(**kwargs)  # type: ignore[arg-type]
22
23
class TestIsCommand:
    """``is_command`` distinguishes slash commands from plain chat text."""

    def test_slash_leading(self) -> None:
        # A slash preceded by whitespace still counts as a command.
        for line in ("/exit", " /exit"):
            assert is_command(line) is True

    def test_plain_text_not_command(self) -> None:
        for line in ("hello", ""):
            assert is_command(line) is False
32
33
class TestExitQuit:
    """``/exit`` and its ``/quit`` alias both yield ``Action.EXIT``."""

    def test_exit_returns_exit_action(self) -> None:
        assert parse_and_dispatch("/exit", _session()).action is Action.EXIT

    def test_quit_is_alias_for_exit(self) -> None:
        assert parse_and_dispatch("/quit", _session()).action is Action.EXIT
42
43
class TestClear:
    """``/clear`` empties the conversation history."""

    def test_clear_wipes_history(self) -> None:
        session = _session()
        session.append_user("x")

        parse_and_dispatch("/clear", session)

        assert session.history == []
50
51
class TestSave:
    """``/save PATH`` writes the transcript and reports what it saved."""

    def test_save_writes_and_returns_message(self, tmp_path: Path) -> None:
        session = _session()
        session.append_user("hi")
        session.append_assistant("hello")
        target = tmp_path / "saved.json"

        outcome = parse_and_dispatch(f"/save {target}", session)

        assert target.exists()
        message = outcome.message
        assert message is not None
        # One user turn plus one assistant turn were appended above.
        assert "2 messages" in message

    def test_save_without_path_raises(self) -> None:
        session = _session()
        with pytest.raises(BadCommandArgumentError, match="requires a path"):
            parse_and_dispatch("/save", session)
66
67
class TestAdapter:
    """``/adapter NAME`` selects one of the session's declared adapters."""

    _DECLARED = ("knowledge", "tone")

    def test_single_adapter_doc_refuses(self) -> None:
        # No declared adapters stands in for a single-adapter document.
        session = _session(declared_adapters=())
        with pytest.raises(BadCommandArgumentError, match="multi-adapter"):
            parse_and_dispatch("/adapter knowledge", session)

    def test_unknown_adapter_refused(self) -> None:
        session = _session(declared_adapters=self._DECLARED)
        with pytest.raises(BadCommandArgumentError, match="not declared"):
            parse_and_dispatch("/adapter ghost", session)

    def test_happy_path_updates_active(self) -> None:
        session = _session(declared_adapters=self._DECLARED)
        parse_and_dispatch("/adapter tone", session)
        assert session.active_adapter == "tone"

    def test_empty_args_raises(self) -> None:
        session = _session(declared_adapters=("x",))
        with pytest.raises(BadCommandArgumentError, match="requires"):
            parse_and_dispatch("/adapter", session)
87
88
class TestParams:
    """``/params`` shows generation parameters or applies ``key=value`` edits."""

    @staticmethod
    def _expect_bad_args(line: str, pattern: str) -> None:
        """Assert that dispatching ``line`` on a fresh session is rejected."""
        with pytest.raises(BadCommandArgumentError, match=pattern):
            parse_and_dispatch(line, _session())

    def test_empty_args_prints_current(self) -> None:
        outcome = parse_and_dispatch("/params", _session())
        assert outcome.message is not None
        assert "temperature=0.7" in outcome.message

    def test_updates_float(self) -> None:
        session = _session()
        parse_and_dispatch("/params temperature=0.5", session)
        assert session.gen_params.temperature == 0.5

    def test_updates_int(self) -> None:
        session = _session()
        parse_and_dispatch("/params max_new_tokens=128", session)
        assert session.gen_params.max_new_tokens == 128

    def test_multiple_assignments(self) -> None:
        session = _session()
        parse_and_dispatch("/params temperature=0.3 top_p=0.9", session)
        params = session.gen_params
        assert (params.temperature, params.top_p) == (0.3, 0.9)

    def test_unknown_key_raises(self) -> None:
        self._expect_bad_args("/params flavor=salty", "unknown key")

    def test_bad_float_raises(self) -> None:
        self._expect_bad_args("/params temperature=hot", "not a number")

    def test_bad_int_raises(self) -> None:
        self._expect_bad_args("/params max_new_tokens=lots", "not an integer")

    def test_missing_equals_raises(self) -> None:
        self._expect_bad_args("/params temperature", "key=value")

    def test_partial_failure_leaves_prior_intact(self) -> None:
        """A bad later entry must not leave an earlier valid entry applied."""
        session = _session()
        before = session.gen_params.temperature
        with pytest.raises(BadCommandArgumentError):
            parse_and_dispatch("/params temperature=0.5 top_p=oops", session)
        assert session.gen_params.temperature == before
135
136
class TestModelAndHistory:
    """``/model`` reports the backend; ``/history`` lists the transcript."""

    def test_model_mentions_backend_name(self) -> None:
        fake_backend = MagicMock()
        # Assigned after construction: ``MagicMock(name=...)`` names the mock
        # itself instead of creating a ``.name`` attribute.
        fake_backend.name = "pytorch"
        outcome = parse_and_dispatch("/model", _session(backend=fake_backend))
        assert outcome.message is not None
        assert "pytorch" in outcome.message

    def test_history_empty_message(self) -> None:
        message = parse_and_dispatch("/history", _session()).message
        assert message is not None
        assert "empty" in message.lower()

    def test_history_lists_messages(self) -> None:
        session = _session()
        session.append_user("q")
        session.append_assistant("a")
        message = parse_and_dispatch("/history", session).message
        assert message is not None
        for role in ("user", "assistant"):
            assert role in message
159
160
class TestHelp:
    """``/help`` output advertises the REPL's slash commands."""

    _EXPECTED_COMMANDS = (
        "/exit",
        "/clear",
        "/save",
        "/adapter",
        "/params",
        "/model",
        "/history",
    )

    def test_help_includes_every_command(self) -> None:
        text = parse_and_dispatch("/help", _session()).message
        assert text is not None
        missing = [cmd for cmd in self._EXPECTED_COMMANDS if cmd not in text]
        assert missing == []
167
168
class TestHelpers:
    """Private formatting helpers used by the command handlers."""

    def test_truncate_adds_ellipsis_for_long_lines(self) -> None:
        # Width 10 leaves room for 9 original characters plus one "…".
        long_line = "x" * 20
        expected = "x" * 9 + "…"
        assert _truncate(long_line, 10) == expected
172
173
class TestUnknownCommand:
    """Lines that do not resolve to a known command raise ``UnknownCommandError``."""

    @staticmethod
    def _reject(line: str, pattern: str) -> None:
        """Assert ``line`` is refused with an error message matching ``pattern``."""
        with pytest.raises(UnknownCommandError, match=pattern):
            parse_and_dispatch(line, _session())

    def test_unknown_slash_raises(self) -> None:
        self._reject("/bogus", "/bogus")

    def test_bare_slash_raises(self) -> None:
        self._reject("/", "empty")

    def test_non_slash_line_raises(self) -> None:
        self._reject("hello", "does not start")
185 parse_and_dispatch("hello", _session())