torch-distributed-demo.py
"""
run::

    python -m torch.distributed.run --standalone --nnodes 1 --nproc-per-node=2 torch-distributed-demo.py
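
or equivalently (torchrun should be the same entry point as python -m torch.distributed.run):

    torchrun --standalone --nnodes 1 --nproc-per-node=2 torch-distributed-demo.py
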
https://pytorch.org/docs/stable/notes/ddp.html
https://github.com/rwth-i6/returnn/issues/1469
https://github.com/pytorch/pytorch/issues/114765
"""
import os
import sys
import time
import subprocess as sp
import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def _debug_mem(msg):
    # Report the memory reserved on GPU 0 and this process' nvidia-smi entry,
    # only from local rank 1 (to show what a non-zero rank holds on GPU 0).
    if local_rank == 1:
        print(f"*** {msg} {{")
        print("reserved mem GPU0:", torch.cuda.memory_reserved(0))
        sp.call(
            f"(nvidia-smi; echo '*** {msg} -- {os.getpid()} }}'; ) | grep {os.getpid()}",
            shell=True,
            stdout=sys.stdout,
            stderr=sys.stdout,
        )
        sys.stdout.flush()

dist.init_process_group(backend=None)  # backend=None selects the defaults: NCCL for CUDA, Gloo for CPU
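# LOCAL_RANK and LOCAL_WORLD_SIZE are set per worker process by torch.distributed.run / torchrun.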
local_rank = int(os.environ["LOCAL_RANK"])
local_size = int(os.environ["LOCAL_WORLD_SIZE"])
dev = torch.device(f"cuda:{local_rank}")
print(f"Start running torch distributed training on local rank {local_rank}/{local_size}.")
_debug_mem("start")
# torch.cuda.set_device(dev) # -- should not be necessary, but currently is
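# Without an explicit set_device, the other ranks can also end up reserving memory on GPU 0
# (e.g. via their CUDA context / NCCL buffers), which is what the _debug_mem calls below make
# visible; see the issues linked in the docstring. Whether set_device is still required may
# depend on the PyTorch version.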

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = Model()
model.to(dev)
_debug_mem("after model init")
ddp_model = DistributedDataParallel(model, device_ids=[local_rank])
_debug_mem("after DDP wrapping")
# define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
_debug_mem("after optimizer init")

step = 0
while True:
    # forward pass
    outputs = ddp_model(torch.randn(20, 10, device=dev))
    labels = torch.randn(20, 10, device=dev)
    # backward pass (reset any accumulated gradients first)
    optimizer.zero_grad()
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()
    print(f"[{local_rank}] step {step}")
    _debug_mem(f"step {step}")
    if step >= 3:
        break
    time.sleep(0.5)
    step += 1
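
# Optional cleanup (a minimal addition, not strictly required for the demo): newer PyTorch
# versions warn if the default process group is not destroyed before the process exits.
dist.destroy_process_group()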