Skip to content

Commit

Permalink
black-format
Browse files Browse the repository at this point in the history
  • Loading branch information
bstriner committed May 20, 2022
1 parent 09e92dd commit b660eff
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 14 deletions.
7 changes: 2 additions & 5 deletions src/sagemaker_training/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
)

with capture_signal(signal.SIGTERM, lambda signalnum, *_: proc.send_signal(signalnum)):
with capture_signal(signal.SIGTERM, lambda signalnum, *_: proc.send_signal(signalnum)):
output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
)
Expand Down Expand Up @@ -219,10 +219,7 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
process = subprocess.Popen(
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
)
with capture_signal(
signal.SIGTERM,
lambda signalnum, *_: process.send_signal(signalnum)
):
with capture_signal(signal.SIGTERM, lambda signalnum, *_: process.send_signal(signalnum)):
return_code = process.wait()
if return_code:
extra_info = None
Expand Down
12 changes: 3 additions & 9 deletions test/unit/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,10 @@ def test_run_python(log, async_shell, async_gather, entry_point_type_script, eve
def _sleep_subprocess(capture_error):
with pytest.raises(errors.ExecuteUserScriptError) as error:
process.check_error(
[
sys.executable,
os.path.abspath(os.path.join(__file__, "../_test_process_helper.py"))
],
[sys.executable, os.path.abspath(os.path.join(__file__, "../_test_process_helper.py"))],
errors.ExecuteUserScriptError,
1,
capture_error=capture_error
capture_error=capture_error,
)
assert int(error.value.return_code) == 21
exit(42)
Expand All @@ -198,10 +195,7 @@ def _sleep_subprocess(capture_error):
@pytest.mark.skipif(sys.version_info != (3, 7), reason="requires python3.7")
@pytest.mark.parametrize("capture_error", [True, False])
def test_check_error_signal(capture_error):
proc = multiprocessing.Process(
target=_sleep_subprocess,
args=(capture_error,)
)
proc = multiprocessing.Process(target=_sleep_subprocess, args=(capture_error,))
proc.start()
time.sleep(1)
os.kill(proc.pid, signal.SIGTERM)
Expand Down

0 comments on commit b660eff

Please sign in to comment.