| 1 |
"""Per-rank I/O helpers — `master_only`, `barrier`, `gather_metrics`.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
from types import SimpleNamespace |
| 6 |
from unittest.mock import MagicMock |
| 7 |
|
| 8 |
import pytest |
| 9 |
|
| 10 |
from dlm.train.distributed.rank_io import ( |
| 11 |
barrier, |
| 12 |
gather_metrics, |
| 13 |
is_main_process, |
| 14 |
master_only, |
| 15 |
) |
| 16 |
|
| 17 |
|
| 18 |
class TestIsMainProcess:
    """Exercise `is_main_process` with absent, explicit, and bare accelerators."""

    def test_none_accelerator_returns_true(self) -> None:
        """Passing `None` as the accelerator yields True."""
        verdict = is_main_process(None)
        assert verdict is True

    def test_rank_0_returns_true(self) -> None:
        """An accelerator flagged `is_main_process=True` yields True."""
        main_rank = SimpleNamespace(is_main_process=True)
        assert is_main_process(main_rank) is True

    def test_non_zero_rank_returns_false(self) -> None:
        """An accelerator flagged `is_main_process=False` yields False."""
        worker_rank = SimpleNamespace(is_main_process=False)
        assert is_main_process(worker_rank) is False

    def test_missing_attribute_defaults_to_true(self) -> None:
        """No `is_main_process` attribute → treated as single-process → True."""
        bare_accelerator = SimpleNamespace()
        assert is_main_process(bare_accelerator) is True
| 31 |
|
| 32 |
|
| 33 |
class TestMasterOnly:
    """Exercise the `master_only` decorator on main, non-main, and absent accelerators."""

    def test_rank_0_calls_function(self) -> None:
        """With `is_main_process=True` the wrapped function runs and its value comes back."""
        seen: list[int] = []

        @master_only
        def emit(_acc: object, value: int) -> int:
            seen.append(value)
            return value

        main_rank = SimpleNamespace(is_main_process=True)
        returned = emit(main_rank, 42)

        assert returned == 42
        assert seen == [42]

    def test_non_zero_rank_skips(self) -> None:
        """With `is_main_process=False` the wrapped function is not run; the call yields None."""
        seen: list[int] = []

        @master_only
        def emit(_acc: object, value: int) -> int:
            seen.append(value)
            return value

        worker_rank = SimpleNamespace(is_main_process=False)
        returned = emit(worker_rank, 42)

        assert returned is None
        assert seen == []

    def test_single_process_calls_function(self) -> None:
        """With `None` as the accelerator the wrapped function runs normally."""
        seen: list[int] = []

        @master_only
        def emit(_acc: object, value: int) -> int:
            seen.append(value)
            return value

        returned = emit(None, 7)

        assert returned == 7
        assert seen == [7]
| 70 |
|
| 71 |
|
| 72 |
class TestBarrier:
    """Exercise `barrier` with no accelerator, a mock, and a stub lacking the method."""

    def test_none_accelerator_is_noop(self) -> None:
        """`barrier(None)` must return without raising."""
        barrier(None)

    def test_calls_wait_for_everyone(self) -> None:
        """`barrier` invokes the accelerator's `wait_for_everyone` exactly once."""
        wait_spy = MagicMock()
        accelerator = MagicMock()
        accelerator.wait_for_everyone = wait_spy

        barrier(accelerator)

        wait_spy.assert_called_once()

    def test_tolerates_missing_method(self) -> None:
        """A stub without `wait_for_everyone` (old versions or mocks) must not raise."""
        stub_accelerator = SimpleNamespace()
        barrier(stub_accelerator)
| 87 |
|
| 88 |
|
| 89 |
class TestGatherMetrics:
    """Exercise `gather_metrics` pass-through, averaging, and None-fallback paths."""

    def test_none_accelerator_passes_through(self) -> None:
        """With `None` as the accelerator the metrics come back unchanged."""
        result = gather_metrics(None, {"loss": 2.5, "ppl": 12.0})
        assert result == {"loss": 2.5, "ppl": 12.0}

    def test_missing_gather_method_passes_through(self) -> None:
        """Accelerator without `gather_for_metrics` degrades gracefully."""
        stub_accelerator = SimpleNamespace()
        result = gather_metrics(stub_accelerator, {"loss": 1.0})
        assert result == {"loss": 1.0}

    def test_gather_averaged_across_ranks(self) -> None:
        """Simulate 2 ranks — gather returns a stacked tensor; expect the mean."""
        import torch

        def fake_two_rank_gather(tensor: torch.Tensor) -> torch.Tensor:
            # Pretend a second rank reported a value exactly 1.0 higher.
            twin = tensor + 1.0
            return torch.stack([tensor, twin])

        accelerator = SimpleNamespace(
            gather_for_metrics=fake_two_rank_gather,
            is_main_process=True,
        )
        result = gather_metrics(accelerator, {"loss": 2.0})

        # mean of [2.0, 3.0] = 2.5
        assert result["loss"] == pytest.approx(2.5)

    def test_gather_none_falls_back_to_original_value(self) -> None:
        """When the gather callable returns None, the input value is kept."""

        def gather_nothing(tensor: object) -> None:
            return None

        accelerator = SimpleNamespace(
            gather_for_metrics=gather_nothing,
            is_main_process=True,
        )
        result = gather_metrics(accelerator, {"loss": 2.0})
        assert result == {"loss": 2.0}