Skip to content

Commit 62240ab

Browse files
committed
add extra process group protection
Signed-off-by: alec-flowers <aflowers@nvidia.com>
1 parent 4e39826 commit 62240ab

File tree

1 file changed

+68
-9
lines changed

1 file changed

+68
-9
lines changed

tests/utils/managed_process.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import os
1919
import shutil
20+
import signal
2021
import socket
2122
import subprocess
2223
import time
@@ -82,6 +83,10 @@ class ManagedProcess:
8283
straggler_commands: List[str] = field(default_factory=list)
8384
log_dir: str = os.getcwd()
8485

86+
# Ensure attributes exist even if startup fails early
87+
proc: Optional[subprocess.Popen] = None
88+
_pgid: Optional[int] = None
89+
8590
_logger = logging.getLogger()
8691
_command_name = None
8792
_log_path = None
@@ -107,20 +112,30 @@ def __enter__(self):
107112

108113
return self
109114

110-
except Exception as e:
111-
self.__exit__(None, None, None)
112-
raise e
115+
except Exception:
116+
try:
117+
self.__exit__(None, None, None)
118+
except Exception as cleanup_err:
119+
self._logger.warning(
120+
"Error during cleanup in __enter__: %s", cleanup_err
121+
)
122+
raise
113123

114124
def __exit__(self, exc_type, exc_val, exc_tb):
125+
self._terminate_process_group()
126+
115127
process_list = [self.proc, self._tee_proc, self._sed_proc]
116128
for process in process_list:
117129
if process:
118-
if process.stdout:
119-
process.stdout.close()
120-
if process.stdin:
121-
process.stdin.close()
122-
terminate_process_tree(process.pid, self._logger)
123-
process.wait()
130+
try:
131+
if process.stdout:
132+
process.stdout.close()
133+
if process.stdin:
134+
process.stdin.close()
135+
terminate_process_tree(process.pid, self._logger)
136+
process.wait()
137+
except Exception as e:
138+
self._logger.warning("Error terminating process: %s", e)
124139
if self.data_dir:
125140
self._remove_directory(self.data_dir)
126141

@@ -169,6 +184,12 @@ def _start_process(self):
169184
stderr=stderr,
170185
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
171186
)
187+
# Capture the child's process group id for robust cleanup even if parent shell exits
188+
try:
189+
self._pgid = os.getpgid(self.proc.pid)
190+
except Exception as e:
191+
self._logger.warning("Could not get process group id: %s", e)
192+
self._pgid = None
172193
self._sed_proc = subprocess.Popen(
173194
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
174195
stdin=self.proc.stdout,
@@ -190,6 +211,12 @@ def _start_process(self):
190211
stderr=stderr,
191212
start_new_session=True, # Isolate process group to prevent kill 0 from affecting parent
192213
)
214+
# Capture the child's process group id for robust cleanup even if parent shell exits
215+
try:
216+
self._pgid = os.getpgid(self.proc.pid)
217+
except Exception as e:
218+
self._logger.warning("Could not get process group id: %s", e)
219+
self._pgid = None
193220

194221
self._sed_proc = subprocess.Popen(
195222
["sed", "-u", f"s/^/[{self._command_name.upper()}] /"],
@@ -198,6 +225,38 @@ def _start_process(self):
198225
)
199226
self._tee_proc = None
200227

228+
def _terminate_process_group(self, timeout: float = 5.0):
    """Terminate the entire process group/session started for the child.

    This catches cases where the launcher shell exits and its children are
    reparented, leaving no parent PID to traverse, but they remain in the
    same process group.

    The group is sent SIGTERM, the caller is blocked for ``timeout`` seconds,
    and the group is then sent SIGKILL. All failures are logged or ignored;
    nothing is raised to the caller, since this runs during cleanup.

    The recorded group id is cleared as soon as it is read, so calling this
    method again is a no-op instead of signalling a stale group id that the
    OS may have handed to unrelated processes in the meantime.

    Args:
        timeout: Seconds to wait between SIGTERM and SIGKILL. Negative values
            are treated as zero.
    """
    pgid = self._pgid
    if pgid is None:
        return

    # Forget the group id before doing anything else: every exit path below
    # must leave the instance unable to signal this group a second time.
    self._pgid = None

    try:
        self._logger.info("Terminating process group: %s", pgid)
        os.killpg(pgid, signal.SIGTERM)
    except ProcessLookupError:
        # No process is left in the group; nothing to clean up.
        return
    except Exception as e:
        self._logger.warning(
            "Error sending SIGTERM to process group %s: %s", pgid, e
        )
        return

    # Give processes a brief moment to exit gracefully. time.sleep() rejects
    # negative values, so clamp instead of failing in the middle of cleanup.
    time.sleep(max(timeout, 0.0))

    # Force kill if anything remains
    try:
        os.killpg(pgid, signal.SIGKILL)
    except ProcessLookupError:
        # Everything exited during the grace period.
        pass
    except Exception as e:
        self._logger.warning(
            "Error sending SIGKILL to process group %s: %s", pgid, e
        )
259+
201260
def _remove_directory(self, path: str) -> None:
202261
"""Remove a directory."""
203262
try:

0 commit comments

Comments
 (0)