forked from sczzz3/EHRDiff
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
116 lines (96 loc) · 3.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import sys
import argparse
from omegaconf import OmegaConf
from utils.util import make_dir
# Select the 'spawn' start method for worker processes (PyTorch requires
# spawn/forkserver to use CUDA in subprocesses), unless a start method has
# already been fixed for this interpreter -- in that case keep it as is.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
def run_main(config):
    """Launch one worker process per local GPU and wait for all of them.

    Each worker runs ``setup(config, main)`` with the per-process rank fields
    of ``config.setup`` filled in (``local_rank``, ``global_rank``,
    ``global_size``).

    Args:
        config: the merged run config; reads ``config.setup.CUDA_DEVICES``,
            ``n_gpus_per_node``, ``node_rank`` and ``n_nodes``.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code,
            so a crashed rank is not silently reported as success.
    """
    # Set in the parent so every spawned worker inherits it.
    # NOTE(review): expects a scalar or a pre-formatted string such as '0,1';
    # a list value would be stringified as '[0, 1]' -- confirm config format.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config.setup.CUDA_DEVICES)

    n_local = config.setup.n_gpus_per_node
    # The world size is identical for every rank, so compute it once.
    config.setup.global_size = config.setup.n_nodes * n_local

    processes = []
    for rank in range(n_local):
        config.setup.local_rank = rank
        config.setup.global_rank = rank + config.setup.node_rank * n_local
        # The config is snapshotted into the child when start() runs (pickled
        # for 'spawn', memory-copied for 'fork'), so overwriting the rank
        # fields on the next iteration does not affect this worker.
        p = mp.Process(target=setup, args=(config, main))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # exitcode is non-zero on an uncaught exception and negative if the worker
    # was killed by a signal.
    failed = [(local_rank, p.exitcode)
              for local_rank, p in enumerate(processes)
              if p.exitcode != 0]
    if failed:
        raise RuntimeError(
            'Worker process(es) failed, as (local rank, exit code): %s' % failed)
def setup(config, fn):
    """Initialise this worker's process group, run ``fn(config)``, then tear it down.

    Intended to run inside each worker process started by ``run_main``.

    Args:
        config: the merged run config; reads ``config.setup.master_address``,
            ``master_port``, ``omp_n_threads``, ``local_rank``,
            ``global_rank`` and ``global_size``.
        fn: per-process entry point, called as ``fn(config)`` once the NCCL
            process group is initialised.

    The process group is destroyed even when ``fn`` raises; the exception
    then propagates out of the worker.
    """
    os.environ['MASTER_ADDR'] = config.setup.master_address
    os.environ['MASTER_PORT'] = '%d' % config.setup.master_port
    # NOTE(review): torch is already imported when this runs, so the OpenMP
    # runtime may have read OMP_NUM_THREADS already and this may mainly affect
    # processes started later; consider torch.set_num_threads() if the
    # intra-op thread count must be capped here -- confirm.
    os.environ['OMP_NUM_THREADS'] = '%d' % config.setup.omp_n_threads
    # Bind this process to its local GPU before creating the NCCL group.
    torch.cuda.set_device(config.setup.local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='env://',  # uses MASTER_ADDR/MASTER_PORT above
                            rank=config.setup.global_rank,
                            world_size=config.setup.global_size)
    try:
        fn(config)
        # Synchronise only on the success path; after a failure in fn there is
        # nothing to wait for, and the barrier would be one more place to hang.
        dist.barrier()
    finally:
        dist.destroy_process_group()
def set_logger(gfile_stream):
    """Route root-logger records to ``gfile_stream`` and set the root level to INFO.

    Args:
        gfile_stream: a writable text stream; a new ``StreamHandler`` on it is
            added to the root logger (existing handlers are kept).
    """
    root = logging.getLogger()
    stream_handler = logging.StreamHandler(gfile_stream)
    stream_handler.setFormatter(
        logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s'))
    root.addHandler(stream_handler)
    root.setLevel('INFO')
def _start_rank0_logging(config, workdir):
    """On global rank 0 only: prepare ``workdir``, log to its stdout.txt, and log the config."""
    if config.setup.global_rank != 0:
        return
    make_dir(workdir)  # project helper from utils.util; presumably creates the directory
    # 'w' truncates any log left by a previous run. The handle is deliberately
    # not closed: the root-logger handler installed by set_logger keeps
    # writing to it for the rest of the process.
    gfile_stream = open(os.path.join(workdir, 'stdout.txt'), 'w')
    set_logger(gfile_stream)
    logging.info(config)


def main(config):
    """Per-process entry point: set up rank-0 logging, then run the configured runner.

    The working directory is ``<config.setup.root_folder>/<config.setup.workdir>``.
    ``config.setup.mode`` selects the job:

    * ``'train'`` -- runner ``'train_dpdm_base'`` calls
      ``runners.train_dpdm_base.training(config, workdir, mode)``;
    * ``'eval'`` -- runner ``'generate_base'`` calls
      ``runners.generate_base.evaluation(config, workdir)``.

    Raises:
        ValueError: if ``config.setup.mode`` is neither ``'train'`` nor ``'eval'``.
        NotImplementedError: if ``config.setup.runner`` is not supported for the mode.
    """
    mode = config.setup.mode
    if mode not in ('train', 'eval'):
        raise ValueError("Unknown mode %r; expected 'train' or 'eval'." % (mode,))

    workdir = os.path.join(config.setup.root_folder, config.setup.workdir)
    _start_rank0_logging(config, workdir)

    # Runner modules are imported lazily so only the selected one is loaded.
    if mode == 'train':
        if config.setup.runner == 'train_dpdm_base':
            from runners import train_dpdm_base
            train_dpdm_base.training(config, workdir, mode)
        else:
            raise NotImplementedError('Runner is not yet implemented.')
    else:  # mode == 'eval'
        if config.setup.runner == 'generate_base':
            from runners import generate_base
            generate_base.evaluation(config, workdir)
        else:
            raise NotImplementedError('Runner is not yet implemented.')
if __name__ == '__main__':
    # Make modules in the current working directory importable
    # (presumably the repo root, for `utils` / `runners` -- TODO confirm).
    sys.path.append(os.getcwd())

    parser = argparse.ArgumentParser()
    # One or more OmegaConf/YAML files; merged left to right, later files win.
    # nargs='+' rejects a bare `--config` with no files, which would otherwise
    # yield an empty config and fail later with a less helpful error.
    parser.add_argument('--config', nargs='+', required=True)
    parser.add_argument('--workdir', required=True)
    parser.add_argument('--mode', choices=['train', 'eval'], required=True)
    parser.add_argument('--root_folder', default='.')
    # Any unrecognised `key=value` arguments become dotlist overrides that are
    # merged last, so they take precedence over the config files.
    opt, unknown = parser.parse_known_args()

    configs = [OmegaConf.load(cfg) for cfg in opt.config]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    config.setup.workdir = opt.workdir
    config.setup.mode = opt.mode
    config.setup.root_folder = opt.root_folder

    if config.setup.n_nodes > 1:
        raise NotImplementedError('This has not been tested.')

    run_main(config)