diff --git a/python/oneflow/distributed/launch.py b/python/oneflow/distributed/launch.py index f6992ca96d0..b058bfaa6fc 100644 --- a/python/oneflow/distributed/launch.py +++ b/python/oneflow/distributed/launch.py @@ -156,16 +156,23 @@ def main(): sig_names = {2: "SIGINT", 15: "SIGTERM"} last_return_code = None + # set killing flag to make sure killing signal only executed once + kill_flag = True + def sigkill_handler(signum, frame): + nonlocal kill_flag + if not kill_flag: + return for process in processes: print(f"Killing subprocess {process.pid}") - try: - # Note: use os.kill or process.kill() may only kill current process - # use killpg will kill(use signal) this process and all sub-processes - # if orphan sub-processes still exist, use signal.SIGKILL instead. - os.killpg(os.getpgid(process.pid), signal.SIGTERM) - except Exception: - pass + kill_flag = False + try: + # Note: use os.kill or process.kill() may only kill current process + # use killpg will kill(use signal) this process and all sub-processes + # if orphan sub-processes still exist, use signal.SIGKILL instead. + os.killpg(os.getpid(), signal.SIGTERM) + except Exception: + pass if last_return_code is not None: raise subprocess.CalledProcessError( returncode=last_return_code, cmd=cmd