# dist_utils.py (forked from pytorch/pytorch)
from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import threading
from functools import partial, wraps

import torch.distributed as dist
import torch.distributed.rpc as rpc

# Skip this module entirely if the c10d distributed package is unavailable.
if not dist.is_available():
    print("c10d not available, skipping tests")
    sys.exit(0)

# Kwargs-only configuration holder for the RPC test suite.
class TestConfig:
    __slots__ = ["rpc_backend_name", "build_rpc_agent_options"]

    def __init__(self, *args, **kwargs):
        assert len(args) == 0, "TestConfig only takes kwargs."
        for k, v in kwargs.items():
            setattr(self, k, v)


TEST_CONFIG = TestConfig()
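
# Minimal usage sketch (hypothetical values, not part of the harness):
# TestConfig rejects positional arguments and stores keyword arguments
# directly on its slots.
#
#     cfg = TestConfig(rpc_backend_name="PROCESS_GROUP")
#     assert cfg.rpc_backend_name == "PROCESS_GROUP"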
INIT_METHOD_TEMPLATE = "file://{file_name}"

MASTER_RANK = 0
# Shutdown-coordination state shared by the master/follower handshake below.
_ALL_NODE_NAMES = set()
_DONE_NODE_NAMES = set()
_TERMINATION_SIGNAL = threading.Event()

def on_master_follower_report_done(worker_name):
    assert (
        worker_name in _ALL_NODE_NAMES
    ), "{worker_name} is not expected by master.".format(worker_name=worker_name)
    assert (
        worker_name not in _DONE_NODE_NAMES
    ), "{worker_name} reported done twice.".format(worker_name=worker_name)
    _DONE_NODE_NAMES.add(worker_name)
    if _ALL_NODE_NAMES != _DONE_NODE_NAMES:
        return
    set_termination_signal()


def set_termination_signal():
    assert not _TERMINATION_SIGNAL.is_set(), "Termination signal got set twice."
    _TERMINATION_SIGNAL.set()
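
# Minimal local illustration of the report/terminate handshake (hypothetical
# two-worker setup; in the real tests the reports arrive over RPC):
#
#     _ALL_NODE_NAMES = {"worker0", "worker1"}
#     on_master_follower_report_done("worker0")  # signal not yet set
#     on_master_follower_report_done("worker1")  # last report sets the signal
#     assert _TERMINATION_SIGNAL.is_set()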

def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
    """
    We use this decorator to set up and tear down state, since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process runs the `test*` method without actually calling the
    `setUp` and `tearDown` methods of unittest.
    """
    # If dist_init is used without arguments (e.g. @dist_init), old_test_method
    # is the decorated function and we return the wrapper directly. If it is
    # used with arguments (e.g. @dist_init(clean_shutdown=False)),
    # old_test_method is None, so we return a functools.partial, which is the
    # real decorator; applying it recursively calls dist_init with
    # old_test_method and the remaining arguments bound. A usage sketch
    # follows the function body below.
    if old_test_method is None:
        return partial(
            dist_init,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
        )

    @wraps(old_test_method)
    def new_test_method(self, *arg, **kwargs):
        self.worker_id = self.rank

        if setup_rpc:
            global _ALL_NODE_NAMES
            _ALL_NODE_NAMES = {
                "worker{}".format(rank) for rank in range(self.world_size)
            }

            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
                rpc_agent_options=self.rpc_agent_options,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_rpc:
            if clean_shutdown:
                # Follower reports done.
                if self.rank == MASTER_RANK:
                    on_master_follower_report_done("worker{}".format(MASTER_RANK))
                else:
                    rpc.rpc_async(
                        "worker{}".format(MASTER_RANK),
                        on_master_follower_report_done,
                        args=("worker{}".format(self.rank),),
                    )

                # Master waits for followers to report done.
                # Follower waits for master's termination command.
                _TERMINATION_SIGNAL.wait()
                if self.rank == MASTER_RANK:
                    # Master sends termination command.
                    futs = []
                    for dst_rank in range(self.world_size):
                        # torch.distributed.rpc does not support sending to self.
                        if dst_rank == MASTER_RANK:
                            continue
                        dst_name = "worker{}".format(dst_rank)
                        fut = rpc.rpc_async(dst_name, set_termination_signal, args=())
                        futs.append(fut)
                    for fut in futs:
                        assert fut.wait() is None, "Sending termination signal failed."

            # Close RPC. We need to do this even without a clean shutdown,
            # since the RPC agent must be shut down; otherwise tests would
            # fail because RPC agent threads, locks, and condition variables
            # are not properly terminated.
            rpc.wait_all_workers()

        return return_value

    return new_test_method
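
# Usage sketch (hypothetical test class; MultiProcessTestCase supplies `rank`,
# `world_size`, `init_method`, etc. in the real harness):
#
#     class MyRpcTest(MultiProcessTestCase):
#         @dist_init  # bare form: the test method is passed as old_test_method
#         def test_default(self):
#             ...
#
#         @dist_init(clean_shutdown=False)  # parametrized form returns a partial
#         def test_fast_shutdown(self):
#             ...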

# Set PROCESS_GROUP as the default RPC backend.
TEST_CONFIG.rpc_backend_name = "PROCESS_GROUP"
TEST_CONFIG.build_rpc_agent_options = lambda test_object: rpc.backend_registry.construct_rpc_agent_options(
    test_object.rpc_backend,
    # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
    num_send_recv_threads=16,
)
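
# Sketch of how a harness might consume this config (assumes a hypothetical
# `test_object` exposing an `rpc_backend` attribute):
#
#     options = TEST_CONFIG.build_rpc_agent_options(test_object)
#     # `options` would then be passed to rpc.init_rpc(...,
#     # rpc_agent_options=options), as in new_test_method above.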