Python · 3699 bytes Raw Blame History
1 """Per-rank I/O helpers — `master_only`, `barrier`, `gather_metrics`."""
2
3 from __future__ import annotations
4
5 from types import SimpleNamespace
6 from unittest.mock import MagicMock
7
8 import pytest
9
10 from dlm.train.distributed.rank_io import (
11 barrier,
12 gather_metrics,
13 is_main_process,
14 master_only,
15 )
16
17
class TestIsMainProcess:
    """`is_main_process` decides rank-0 status from an accelerator-like object."""

    def test_none_accelerator_returns_true(self) -> None:
        # No accelerator at all: the caller is the sole (hence main) process.
        result = is_main_process(None)
        assert result is True

    def test_rank_0_returns_true(self) -> None:
        rank_zero = SimpleNamespace(is_main_process=True)
        assert is_main_process(rank_zero) is True

    def test_non_zero_rank_returns_false(self) -> None:
        worker = SimpleNamespace(is_main_process=False)
        assert is_main_process(worker) is False

    def test_missing_attribute_defaults_to_true(self) -> None:
        """Missing `is_main_process` attr → assume single-process → True."""
        bare = SimpleNamespace()
        assert is_main_process(bare) is True
31
32
class TestMasterOnly:
    """`master_only` runs the wrapped callable only when the accelerator is rank 0."""

    @staticmethod
    def _recording_writer(sink: list[int]):
        """Return a `master_only`-decorated writer that appends each value to *sink*.

        The writer takes the accelerator as its first positional argument and
        echoes back ``value``. This lets tests check both whether the body ran
        and what the decorator returned.
        """

        @master_only
        def write(_acc: object, value: int) -> int:
            sink.append(value)
            return value

        return write

    def test_rank_0_calls_function(self) -> None:
        seen: list[int] = []
        write = self._recording_writer(seen)
        assert write(SimpleNamespace(is_main_process=True), 42) == 42
        assert seen == [42]

    def test_non_zero_rank_skips(self) -> None:
        seen: list[int] = []
        write = self._recording_writer(seen)
        # Non-main rank: the body must not run, and the wrapper yields None.
        assert write(SimpleNamespace(is_main_process=False), 42) is None
        assert seen == []

    def test_single_process_calls_function(self) -> None:
        seen: list[int] = []
        write = self._recording_writer(seen)
        # A None accelerator means single-process mode, so the body runs.
        assert write(None, 7) == 7
        assert seen == [7]
70
71
class TestBarrier:
    """`barrier` forwards to `wait_for_everyone` when the accelerator provides it."""

    def test_none_accelerator_is_noop(self) -> None:
        # No accelerator: the call must simply return without raising.
        barrier(None)

    def test_calls_wait_for_everyone(self) -> None:
        wait = MagicMock()
        acc = MagicMock(wait_for_everyone=wait)
        barrier(acc)
        wait.assert_called_once()

    def test_tolerates_missing_method(self) -> None:
        # An accelerator stub lacking `wait_for_everyone` (older library
        # versions or minimal fakes) must not make `barrier` raise.
        barrier(SimpleNamespace())
87
88
class TestGatherMetrics:
    """`gather_metrics` reduces metrics across ranks, or passes them through unchanged."""

    def test_none_accelerator_passes_through(self) -> None:
        out = gather_metrics(None, {"loss": 2.5, "ppl": 12.0})
        assert out == {"loss": 2.5, "ppl": 12.0}

    def test_missing_gather_method_passes_through(self) -> None:
        """Accelerator without `gather_for_metrics` degrades gracefully."""
        out = gather_metrics(SimpleNamespace(), {"loss": 1.0})
        assert out == {"loss": 1.0}

    def test_gather_averaged_across_ranks(self) -> None:
        """Simulate 2 ranks — gather returns a stacked tensor; take mean."""
        # Skip (rather than error) when torch is not installed, so the
        # torch-free tests in this class still give a clean signal. The real
        # import follows so type checkers see the actual `torch` module.
        pytest.importorskip("torch")
        import torch

        def _gather(tensor: torch.Tensor) -> torch.Tensor:
            # Simulate a two-rank gather: this rank's value plus a twin offset by 1.0.
            return torch.stack([tensor, tensor + 1.0])

        acc = SimpleNamespace(gather_for_metrics=_gather, is_main_process=True)
        out = gather_metrics(acc, {"loss": 2.0})
        # mean of [2.0, 3.0] = 2.5
        assert out["loss"] == pytest.approx(2.5)

    def test_gather_none_falls_back_to_original_value(self) -> None:
        """A gather that returns None leaves the metric at its local value."""
        acc = SimpleNamespace(gather_for_metrics=lambda tensor: None, is_main_process=True)
        out = gather_metrics(acc, {"loss": 2.0})
        assert out == {"loss": 2.0}
115 assert out == {"loss": 2.0}