-
Notifications
You must be signed in to change notification settings - Fork 17
/
train.py
executable file
·120 lines (97 loc) · 3.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
"""
Main training workflow
"""
import configargparse
import os
import signal
import torch
import onmt.opts as opts
import onmt.utils.distributed
from onmt.utils.logging import logger
from onmt.train_single import main as single_main
def main(opt):
    """Check the parsed options, then start single- or multi-process training.

    Args:
        opt: parsed option namespace. The attributes read here are
            ``rnn_type``, ``gpu_ranks``, ``epochs``, ``truncated_decoder``,
            ``accum_count``, ``gpuid`` and ``world_size``.

    Raises:
        AssertionError: if SRU is requested without ``gpu_ranks``, if the
            deprecated ``epochs`` or ``gpuid`` options are set, or if
            ``truncated_decoder > 0`` is combined with ``accum_count > 1``.
    """
    sru_without_gpu = opt.rnn_type == "SRU" and not opt.gpu_ranks
    if sru_without_gpu:
        raise AssertionError("Using SRU requires -gpu_ranks set.")

    if opt.epochs:
        raise AssertionError("-epochs is deprecated please use -train_steps.")

    bptt_with_accumulation = (
        opt.truncated_decoder > 0 and opt.accum_count > 1)
    if bptt_with_accumulation:
        raise AssertionError("BPTT is not compatible with -accum > 1")

    if opt.gpuid:
        raise AssertionError(
            "gpuid is deprecated "
            "see world_size and gpu_ranks")

    gpu_count = len(opt.gpu_ranks)

    if opt.world_size <= 1:
        # Single process: device 0 when exactly one GPU rank is configured,
        # otherwise -1 (the CPU-only case).
        single_main(opt, 0 if gpu_count == 1 else -1)
        return

    # Train with multiprocessing: one spawned daemon process per GPU rank.
    spawn_ctx = torch.multiprocessing.get_context('spawn')
    # The handler starts a thread that listens for errors reported by the
    # child processes on this queue.
    error_queue = spawn_ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    workers = []
    for device_id in range(gpu_count):
        worker = spawn_ctx.Process(
            target=run,
            args=(opt, device_id, error_queue),
            daemon=True)
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d " % worker.pid)
        error_handler.add_child(worker.pid)

    for worker in workers:
        worker.join()
def run(opt, device_id, error_queue):
    """Entry point of one training worker process.

    Initialises distributed training for ``device_id``, checks that the rank
    returned by ``multi_init`` equals ``opt.gpu_ranks[device_id]``, then runs
    ``single_main``. Nothing is raised to the caller for ordinary failures:
    they are reported on ``error_queue`` instead.

    Args:
        opt: parsed option namespace; ``opt.gpu_ranks`` is indexed by
            ``device_id``.
        device_id: index of this worker within ``opt.gpu_ranks``.
        error_queue: queue that receives a ``(rank, traceback_text)`` tuple
            when the worker fails with an ``Exception``.
    """
    try:
        assigned_rank = onmt.utils.distributed.multi_init(opt, device_id)
        expected_rank = opt.gpu_ranks[device_id]
        if assigned_rank != expected_rank:
            raise AssertionError(
                "An error occurred in "
                "Distributed initialization")
        single_main(opt, device_id)
    except KeyboardInterrupt:
        # Killed by the parent process: exit quietly.
        return
    except Exception:
        # Hand the failure to the parent as text so the original traceback
        # is kept.
        from traceback import format_exc
        failure_report = (opt.gpu_ranks[device_id], format_exc())
        error_queue.put(failure_report)
class ErrorHandler(object):
    """Listen for exceptions in child processes and re-raise them in the parent.

    Children report failures as ``(rank, traceback_text)`` tuples on
    ``error_queue``. A daemon thread waits for the first report and then
    sends SIGUSR1 to the current process; the SIGUSR1 handler interrupts
    every registered child and raises an ``Exception`` whose message carries
    the child's traceback text.
    """

    def __init__(self, error_queue):
        """Install the SIGUSR1 handler and start the listener thread.

        Args:
            error_queue: queue on which children put
                ``(rank, traceback_text)`` tuples.
        """
        import threading

        self.error_queue = error_queue
        self.children_pids = []
        # Register the handler before the listener thread exists, so the
        # SIGUSR1 it sends can never arrive while no handler is installed.
        signal.signal(signal.SIGUSR1, self.signal_handler)
        self.error_thread = threading.Thread(
            target=self.error_listener, daemon=True)
        self.error_thread.start()

    def add_child(self, pid):
        """Register the pid of a child process to interrupt on failure."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for the first error report, then signal this process."""
        report = self.error_queue.get()
        # Put the report back: signal_handler reads it from the queue again.
        self.error_queue.put(report)
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt all children and raise the reported traceback.

        Args:
            signalnum: signal number passed by the ``signal`` module (unused).
            stackframe: interrupted stack frame (unused).

        Raises:
            Exception: always; the message ends with the traceback text taken
                from ``error_queue``.
        """
        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)  # kill children processes
            except ProcessLookupError:
                # The child is already gone (typically the one that reported
                # the failure). Keep going so the remaining children are
                # still interrupted and the real traceback is still raised.
                continue
        (_, original_trace) = self.error_queue.get()
        msg = ("\n\n-- Tracebacks above this line can probably\n"
               "be ignored --\n\n")
        msg += original_trace
        raise Exception(msg)
if __name__ == "__main__":
    # Command-line parser that can also read its values from a YAML config
    # file and shows defaults in the help output.
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        description='train.py')

    # Each onmt.opts helper is called with the parser, in this order.
    for register_options in (opts.config_opts,
                             opts.add_md_help_argument,
                             opts.model_opts,
                             opts.train_opts):
        register_options(parser)

    opt = parser.parse_args()
    main(opt)