main.py
import os

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from loguru import logger

from opts import get_opts
from pl_trainer import ACPL_Trainer


def main():
    # Silence the multiprocessing semaphore_tracker warning emitted by spawned workers.
    os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
    args = get_opts()
    cudnn.benchmark = True

    args.distributed = args.world_size > 0 or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()
    # Scale world_size from the number of nodes to the total number of processes.
    args.world_size = ngpus_per_node * args.world_size
    # Launch one worker process per GPU on this node.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    if args.gpu is not None:
        if args.multiprocessing_distributed:
            logger.info(f"Use GPU: {args.gpu} for training", enqueue=True)

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    trainer = ACPL_Trainer(args)
    trainer.pipeline_master(args)


if __name__ == "__main__":
    main()
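
# Example launch (a sketch, not taken from the repo): the exact CLI flag spellings
# are defined in opts.get_opts, which is not shown here. The flags below simply
# mirror the args attributes read above (world_size, rank, dist_url, dist_backend,
# multiprocessing_distributed) and assume a single-node, all-GPU run.
#
#   python main.py --multiprocessing_distributed --world_size 1 --rank 0 \
#       --dist_url "tcp://127.0.0.1:23456" --dist_backend nccl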