dist_utils.py
import os
from datetime import timedelta

import torch


def init_distributed_env(args):
    # Initialize the distributed environment.
    # World size / rank / local rank are read from the torchrun-style environment
    # variables, falling back to their SLURM equivalents, and default to a single process.
    args.world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS', 1)))
    args.distributed = args.world_size > 1
    args.rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID', 0)))
    args.local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID', 0)))
    args.gpu = args.local_rank

    if args.distributed:
        # Bind this process to its GPU and join the NCCL process group.
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(hours=1))
        obtained_world_size = torch.distributed.get_world_size()
        assert obtained_world_size == args.world_size, f"{obtained_world_size} != {args.world_size}"
        print(f"Initializing the environment with {args.world_size} processes / Process rank: {args.rank} / Local rank: {args.local_rank}")
        setup_for_distributed(args.local_rank == 0)  # print via one process per node

    args.effective_batch_size = args.batch_size * args.world_size
    print(f"# processes: {args.world_size} / batch size: {args.batch_size} / effective batch size: {args.effective_batch_size}")


def is_main_proc(local_rank=None, shared_fs=True):
    # With a shared filesystem, only global rank 0 counts as the main process;
    # otherwise, local rank 0 on each node does (so local_rank must be provided).
    assert shared_fs or local_rank is not None
    main_proc = not torch.distributed.is_initialized() or (torch.distributed.get_rank() == 0 if shared_fs else local_rank == 0)
    return main_proc


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process, unless the call passes force=True.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
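
# Illustrative note: after setup_for_distributed(False) runs on a process, plain print(...)
# calls from that process are suppressed, while print(..., force=True) still reaches stdout
# via the saved builtin print.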


def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1


def wait_for_other_procs():
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def reduce_tensor(tensor, average=False):
    # Sum the tensor across all processes (optionally averaging); returns the input
    # unchanged when running single-process.
    world_size = get_world_size()
    if world_size == 1:
        return tensor
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    if average:
        rt /= world_size
    return rt
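

# --- Minimal usage sketch (not part of the original module) ---
# Shows how these helpers would typically be wired into a training script.
# The argparse.Namespace below and its batch_size field are assumptions for
# illustration; a real script would pass its own parsed arguments. Running this
# file single-process (without torchrun/srun) exercises the non-distributed fallbacks.
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(batch_size=32)  # hypothetical stand-in for parsed CLI args
    init_distributed_env(args)

    if is_main_proc():
        # Only the main process would write checkpoints or log metrics.
        print("main process: checkpointing / logging would happen here")

    # Average a scalar metric across processes (a no-op when world_size == 1).
    # NCCL requires CUDA tensors, so place the tensor on this process's GPU when distributed.
    device = torch.device(f"cuda:{args.gpu}") if args.distributed else torch.device("cpu")
    loss = torch.tensor(1.0, device=device)
    mean_loss = reduce_tensor(loss, average=True)
    print(f"mean loss across {get_world_size()} process(es): {mean_loss.item()}")

    # Synchronize all processes before exiting.
    wait_for_other_procs()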