-
Notifications
You must be signed in to change notification settings - Fork 5
/
distributed_train.py
executable file
·83 lines (68 loc) · 3.47 KB
/
distributed_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# *****************************************************************************
# Adapted from https://github.com/NVIDIA/waveglow/blob/master/distributed.py
# *****************************************************************************
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************\
import os
import sys
import time
import subprocess
import argparse
import warnings
warnings.filterwarnings("ignore")
import torch
from distributed_util import *
def main(config, stdout_dir, args_str):
args_list = ['train.py']
args_list += args_str.split(' ') if len(args_str) > 0 else []
args_list.append('--config={}'.format(config))
num_gpus = torch.cuda.device_count()
args_list.append('--num_gpus={}'.format(num_gpus))
args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S")))
if not os.path.isdir(stdout_dir):
os.makedirs(stdout_dir)
os.chmod(stdout_dir, 0o775)
workers = []
for i in range(num_gpus):
args_list[-2] = '--rank={}'.format(i)
stdout = None if i == 0 else open(
os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w")
print(args_list)
p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout)
workers.append(p)
for p in workers:
p.wait()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
help='JSON file for configuration')
parser.add_argument('-s', '--stdout_dir', type=str, default="exp/",
help='directory to save stoud logs')
parser.add_argument('-a', '--args_str', type=str, default='',
help='double quoted string with space separated key value pairs')
args = parser.parse_args()
main(args.config, args.stdout_dir, args.args_str)