Python · 13426 bytes Raw Blame History
1 """Runtime coverage for the peer share transport."""
2
3 from __future__ import annotations
4
5 import importlib
6 import socket
7 from io import BytesIO
8 from pathlib import Path
9 from types import SimpleNamespace
10
11 import pytest
12
13 from dlm.share.errors import PeerAuthError, RateLimitError
14 from dlm.share.peer import (
15 RateLimiter,
16 ServeHandle,
17 ServeOptions,
18 _detect_lan_ip,
19 _log_connection,
20 build_handler,
21 new_session,
22 pull_peer,
23 serve,
24 )
25
# Handle on the ``dlm.share.peer`` module object itself. The tests below
# monkeypatch module-level attributes through it (``_log_connection``,
# ``PeerSession``, ``new_session``, ``build_handler``, ``resolve_bind``,
# ``_detect_lan_ip``) and reach ``peer_mod.http.server`` via it.
# NOTE(review): presumably importlib is used instead of
# ``import dlm.share.peer as peer_mod`` so that a same-named attribute on the
# ``dlm.share`` package cannot shadow the submodule — confirm.
peer_mod = importlib.import_module("dlm.share.peer")
27
28
def _build_test_handler(
    tmp_path: Path,
    *,
    path: str,
) -> tuple[type[object], object, list[tuple[str, str, str, str]], RateLimiter, Path]:
    """Create a peer request handler instance suitable for direct method calls.

    A pack file ``bundle.dlm.pack`` holding ``b"peer-pack"`` is written under
    *tmp_path*. A session for dlm_id ``01HZPEER`` and a ``RateLimiter``
    (concurrency 4, 30 requests/min) are wired into ``build_handler``.

    The instance is allocated with ``object.__new__``, so the handler class's
    ``__init__`` never runs. Only ``path`` and ``client_address`` are set by
    hand. ``send_error`` and ``_stream_pack`` are replaced with recorders that
    append 4-tuples to the returned event list:

    * ``("error", str(code), message, "")`` for each ``send_error`` call
    * ``("stream", str(path), "", "")`` for each ``_stream_pack`` call

    Returns ``(handler_cls, handler, events, rate_limiter, pack_path)``.
    """
    pack_path = tmp_path / "bundle.dlm.pack"
    session = new_session("01HZPEER")
    pack_path.write_bytes(b"peer-pack")
    rate_limiter = RateLimiter(max_concurrency=4, rate_limit_per_min=30)

    events: list[tuple[str, str, str, str]] = []

    def _record_error(code: object, message: str) -> None:
        events.append(("error", str(code), message, ""))

    # Parameter keeps the name ``path`` so keyword callers still match.
    def _record_stream(path: Path) -> None:
        events.append(("stream", str(path), "", ""))

    handler_cls = build_handler(session, pack_path, rate_limiter)
    handler = object.__new__(handler_cls)
    handler.client_address = ("127.0.0.1", 7337)
    handler.path = path
    handler.send_error = _record_error  # type: ignore[attr-defined]
    handler._stream_pack = _record_stream  # type: ignore[attr-defined]
    return handler_cls, handler, events, rate_limiter, pack_path
47
48
class TestPeerHandler:
    """Drive the peer request handler's methods directly, without a live socket.

    Each test builds a handler via ``_build_test_handler``, which stubs
    ``send_error`` and ``_stream_pack`` into an ``events`` list. Most tests also
    replace ``peer._log_connection`` with a recorder so the request log can be
    asserted exactly.
    """

    @staticmethod
    def _capture_request_logs(
        monkeypatch: pytest.MonkeyPatch,
    ) -> list[tuple[str, str, str, str]]:
        """Swap ``peer._log_connection`` for a recorder and return its sink.

        Each call to the patched function appends ``(ip, method, path, status)``.
        """
        request_logs: list[tuple[str, str, str, str]] = []

        # Parameter names mirror the original lambda so keyword calls still bind.
        def _record(ip: str, method: str, path: str, status: str) -> None:
            request_logs.append((ip, method, path, status))

        monkeypatch.setattr(peer_mod, "_log_connection", _record)
        return request_logs

    @staticmethod
    def _raiser(exc: BaseException) -> object:
        """Return a callable that raises *exc* whatever arguments it receives.

        A lambda cannot contain a ``raise`` statement, so a small closure is
        used instead. It accepts any arguments, which makes it usable both as
        a class-level method replacement (receives ``self``) and as a no-arg
        instance attribute.
        """

        def _raise(*_args: object, **_kwargs: object) -> None:
            raise exc

        return _raise

    def test_log_message_is_silent(
        self,
        tmp_path: Path,
        capsys: pytest.CaptureFixture[str],
    ) -> None:
        handler_cls, handler, _logs, _rate_limiter, _pack_path = _build_test_handler(
            tmp_path, path="/ignored"
        )

        assert handler_cls.log_message(handler, "%s", "ignored") is None
        # The stock BaseHTTPRequestHandler.log_message also returns None but
        # writes an access line to stderr; silence is what must be proven.
        assert capsys.readouterr().err == ""

    def test_handler_rejects_unknown_dlm_id(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
            tmp_path, path="/wrong?token=abc"
        )
        request_logs = self._capture_request_logs(monkeypatch)

        handler_cls.do_GET(handler)

        assert events == [("error", "404", "unknown dlm_id", "")]
        assert request_logs == [
            ("127.0.0.1", "GET", "/wrong", "start"),
            ("127.0.0.1", "GET", "/wrong", "404 unknown dlm_id"),
        ]

    def test_handler_rejects_missing_token(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
            tmp_path, path="/01HZPEER"
        )
        request_logs = self._capture_request_logs(monkeypatch)

        handler_cls.do_GET(handler)

        assert events == [("error", "401", "missing token", "")]
        assert request_logs == [
            ("127.0.0.1", "GET", "/01HZPEER", "start"),
            ("127.0.0.1", "GET", "/01HZPEER", "401 missing token"),
        ]

    def test_handler_rejects_bad_token(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        handler_cls, handler, events, _rate_limiter, _pack_path = _build_test_handler(
            tmp_path, path="/01HZPEER?token=bad"
        )
        request_logs = self._capture_request_logs(monkeypatch)
        monkeypatch.setattr(
            peer_mod.PeerSession,
            "verify_token",
            self._raiser(PeerAuthError("bad token")),
        )

        handler_cls.do_GET(handler)

        assert events == [("error", "403", "token rejected", "")]
        assert request_logs == [
            ("127.0.0.1", "GET", "/01HZPEER", "start"),
            ("127.0.0.1", "GET", "/01HZPEER", "403 bad token"),
        ]

    def test_handler_rejects_rate_limited(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        handler_cls, handler, events, rate_limiter, _pack_path = _build_test_handler(
            tmp_path, path="/01HZPEER?token=good"
        )
        request_logs = self._capture_request_logs(monkeypatch)
        monkeypatch.setattr(peer_mod.PeerSession, "verify_token", lambda self, token: None)
        monkeypatch.setattr(
            rate_limiter,
            "check_and_acquire",
            self._raiser(RateLimitError("too many")),
        )

        handler_cls.do_GET(handler)

        assert events == [("error", "429", "rate limited", "")]
        assert request_logs == [
            ("127.0.0.1", "GET", "/01HZPEER", "start"),
            ("127.0.0.1", "GET", "/01HZPEER", "429 too many"),
        ]

    def test_handler_streams_pack_and_releases_limiter(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        handler_cls, handler, events, rate_limiter, pack_path = _build_test_handler(
            tmp_path, path="/01HZPEER?token=good"
        )
        request_logs = self._capture_request_logs(monkeypatch)
        monkeypatch.setattr(peer_mod.PeerSession, "verify_token", lambda self, token: None)

        # Wrap the real acquire so we can prove a slot was taken. Without this,
        # ``active == 0`` below would also hold if the limiter were never used.
        acquisitions: list[tuple[tuple[object, ...], dict[str, object]]] = []
        real_acquire = rate_limiter.check_and_acquire

        def _tracking_acquire(*args: object, **kwargs: object) -> object:
            acquisitions.append((args, kwargs))
            return real_acquire(*args, **kwargs)

        monkeypatch.setattr(rate_limiter, "check_and_acquire", _tracking_acquire)

        handler_cls.do_GET(handler)

        assert events == [("stream", str(pack_path), "", "")]
        assert len(acquisitions) == 1
        assert rate_limiter.active == 0
        assert request_logs == [
            ("127.0.0.1", "GET", "/01HZPEER", "start"),
            ("127.0.0.1", "GET", "/01HZPEER", "200 complete"),
        ]

    def test_stream_pack_writes_headers_and_body(self, tmp_path: Path) -> None:
        handler_cls, handler, _events, _rate_limiter, pack_path = _build_test_handler(
            tmp_path, path="/ignored"
        )
        responses: list[tuple[str, str]] = []
        body = BytesIO()
        handler.wfile = body
        handler.send_response = lambda status: responses.append(("status", str(status)))  # type: ignore[attr-defined]
        handler.send_header = lambda name, value: responses.append((name, value))  # type: ignore[attr-defined]
        handler.end_headers = lambda: responses.append(("end", ""))  # type: ignore[attr-defined]

        handler_cls._stream_pack(handler, pack_path)

        assert responses == [
            ("status", "200"),
            ("Content-Type", "application/octet-stream"),
            ("Content-Length", str(len(b"peer-pack"))),
            ("end", ""),
        ]
        assert body.getvalue() == b"peer-pack"
204
205
class TestPeerHelpers:
    """Module-level helpers: connection logging and the ``peer://`` pull path."""

    def test_log_connection_emits_metadata_only(self, caplog: pytest.LogCaptureFixture) -> None:
        caplog.set_level("INFO")

        _log_connection("127.0.0.1", "GET", "/01HZPEER", "200 complete")

        expected_line = "peer: GET /01HZPEER 200 complete from 127.0.0.1"
        assert expected_line in caplog.text

    def test_pull_peer_reuses_url_sink(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        # Imported locally so the attribute patched below is the one pull_peer uses.
        import dlm.share.url_sink as url_sink

        destination = tmp_path / "incoming.dlm.pack"
        captured: dict[str, object] = {}

        # Signature mirrors url_sink.pull_url so positional/keyword calls both bind.
        def _fake_pull_url(url: str, actual_out: Path, *, progress: object | None = None) -> int:
            captured.update(url=url, out=actual_out, progress=progress)
            return 42

        monkeypatch.setattr(url_sink, "pull_url", _fake_pull_url)

        bytes_written = pull_peer("host:7337/01HZPEER?token=abc", destination, progress=None)

        expected_call = {
            "url": "http://host:7337/01HZPEER?token=abc",
            "out": destination,
            "progress": None,
        }
        assert bytes_written == 42
        assert captured == expected_call
240
241
class TestServeHandle:
    """``ServeHandle`` URL rendering and shutdown sequencing."""

    @staticmethod
    def _make_handle(
        *,
        bind_host: str = "127.0.0.1",
        server: object | None = None,
    ) -> ServeHandle:
        """Build a handle for dlm_id ``01HZPEER`` on port 7337 with token ``abc``.

        When *server* is omitted an empty ``SimpleNamespace`` stands in for it.
        """
        return ServeHandle(
            session=SimpleNamespace(dlm_id="01HZPEER"),
            bind_host=bind_host,
            port=7337,
            token="abc",
            _server=SimpleNamespace() if server is None else server,
        )

    def test_peer_url_uses_bind_host_for_loopback(self) -> None:
        handle = self._make_handle(bind_host="127.0.0.1")

        assert handle.peer_url == "peer://127.0.0.1:7337/01HZPEER?token=abc"

    def test_peer_url_detects_lan_ip_for_public_bind(self, monkeypatch: pytest.MonkeyPatch) -> None:
        handle = self._make_handle(bind_host="0.0.0.0")
        monkeypatch.setattr(peer_mod, "_detect_lan_ip", lambda: "192.168.1.9")

        assert handle.peer_url == "peer://192.168.1.9:7337/01HZPEER?token=abc"

    def test_wait_shutdown_stops_server_cleanly(self) -> None:
        calls: list[str] = []
        handle = self._make_handle(
            server=SimpleNamespace(
                serve_forever=lambda: calls.append("serve_forever"),
                shutdown=lambda: calls.append("shutdown"),
                server_close=lambda: calls.append("server_close"),
            )
        )

        handle.wait_shutdown()

        assert calls == ["serve_forever", "shutdown", "server_close"]

    def test_wait_shutdown_handles_keyboard_interrupt(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        calls: list[str] = []

        def _interrupted_serve() -> None:
            # Simulate Ctrl-C arriving while the server loop is running.
            calls.append("serve_forever")
            raise KeyboardInterrupt

        handle = self._make_handle(
            server=SimpleNamespace(
                serve_forever=_interrupted_serve,
                shutdown=lambda: calls.append("shutdown"),
                server_close=lambda: calls.append("server_close"),
            )
        )
        caplog.set_level("INFO")

        handle.wait_shutdown()

        assert calls == ["serve_forever", "shutdown", "server_close"]
        assert "shutdown requested" in caplog.text
312
313
class TestServe:
    """``serve`` wiring: session, handler, bind address and HTTP server construction."""

    def test_serve_builds_handle(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        pack_path = tmp_path / "bundle.dlm.pack"
        pack_path.write_bytes(b"peer-pack")
        fake_handler_cls = type("FakeHandler", (), {})
        constructed: dict[str, object] = {}

        class FakeSession:
            dlm_id = "01HZPEER"

            def issue_token(self) -> str:
                return "issued-token"

        class FakeServer:
            # Records its constructor arguments instead of binding a socket.
            def __init__(self, address: tuple[str, int], handler: type[object]) -> None:
                constructed.update(address=address, handler=handler)

        # Replacement callables keep the parameter names serve() may pass by keyword.
        module_stubs = {
            "new_session": lambda dlm_id, token_ttl_seconds: FakeSession(),
            "build_handler": lambda session, actual_pack, limiter: fake_handler_cls,
            "resolve_bind": lambda opts: "127.0.0.1",
        }
        for attr_name, stub in module_stubs.items():
            monkeypatch.setattr(peer_mod, attr_name, stub)
        monkeypatch.setattr(peer_mod.http.server, "ThreadingHTTPServer", FakeServer)

        handle = serve("01HZPEER", pack_path, ServeOptions(port=8123))

        assert handle.session.dlm_id == "01HZPEER"
        assert handle.bind_host == "127.0.0.1"
        assert handle.port == 8123
        assert handle.token == "issued-token"
        assert constructed == {
            "address": ("127.0.0.1", 8123),
            "handler": fake_handler_cls,
        }
351
352
class TestDetectLanIp:
    """``_detect_lan_ip`` with ``socket.socket`` replaced by in-memory fakes."""

    def test_detect_lan_ip_returns_socket_address(self, monkeypatch: pytest.MonkeyPatch) -> None:
        class _RoutableSocket:
            """Fake socket that checks its calls and reports a fixed local address."""

            def __enter__(self) -> _RoutableSocket:
                return self

            def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
                return None

            def settimeout(self, value: float) -> None:
                assert value == 0.1

            def connect(self, target: tuple[str, int]) -> None:
                assert target == ("10.254.254.254", 1)

            def getsockname(self) -> tuple[str, int]:
                return ("192.168.1.7", 9999)

        monkeypatch.setattr(socket, "socket", lambda *args, **kwargs: _RoutableSocket())

        assert _detect_lan_ip() == "192.168.1.7"

    def test_detect_lan_ip_returns_placeholder_on_error(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        class _UnroutableSocket:
            """Fake socket whose context entry fails as if no route exists."""

            def __enter__(self) -> _UnroutableSocket:
                raise OSError("no route")

            def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
                return None

        monkeypatch.setattr(socket, "socket", lambda *args, **kwargs: _UnroutableSocket())

        assert _detect_lan_ip() == "<lan-ip>"
388 assert _detect_lan_ip() == "<lan-ip>"