Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions deepspeed/launcher/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import os
import json
import base64
import time
import signal
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER

Expand Down Expand Up @@ -122,11 +124,47 @@ def main():
args.training_script,
"--local_rank={}".format(local_rank)
] + args.training_script_args

sig_names = {2: "SIGINT", 15: "SIGTERM"}
last_return_code = None

def sigkill_handler(signum, frame):
for process in processes:
print(f"Killing subprocess {process.pid}")
try:
process.kill()
except Exception as e:
pass
if last_return_code is not None:
raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
if signum in sig_names:
print(f"Main process received {sig_names[signum]}, exiting")
sys.exit(1)

# pass SIGINT/SIGTERM to children if the parent is being terminated
signal.signal(signal.SIGINT, sigkill_handler)
signal.signal(signal.SIGTERM, sigkill_handler)

process = subprocess.Popen(cmd, env=current_env)
processes.append(process)

for process in processes:
process.wait()
alive_processes = set(processes)
while len(alive_processes):
finished_processes = []
for process in alive_processes:
if process.poll() is None:
# the process is still running
continue
else:
if process.returncode != 0:
last_return_code = process.returncode # for sigkill_handler
sigkill_handler(signal.SIGTERM, None) # not coming back
else:
# exited cleanly
finished_processes.append(process)
alive_processes = set(alive_processes) - set(finished_processes)

time.sleep(1)


if __name__ == "__main__":
Expand Down