From 35c0a2078302dee25af1333352a29b6afa0678a6 Mon Sep 17 00:00:00 2001 From: Zeb Burke-Conte Date: Sat, 3 Aug 2024 22:00:59 -0700 Subject: [PATCH] Feature: Fork kernel --- ipykernel/ipkernel.py | 31 +++++++------ ipykernel/kernelapp.py | 95 +++++++++++++++++++++++++++++++++++++- ipykernel/kernelbase.py | 23 ++++++++- tests/test_kernel.py | 45 ++++++++++++++++++ tests/test_message_spec.py | 6 +++ tests/utils.py | 13 ++++++ 6 files changed, 196 insertions(+), 17 deletions(-) diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index db83d986f..283bccb6a 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -117,6 +117,23 @@ def __init__(self, **kwargs): self.debug_just_my_code, ) + self.init_shell() + + if _use_appnope() and self._darwin_app_nap: + # Disable app-nap as the kernel is not a gui but can have guis + import appnope # type:ignore[import-untyped] + + appnope.nope() + + self._new_threads_parent_header = {} + self._initialize_thread_hooks() + + if hasattr(gc, "callbacks"): + # while `gc.callbacks` exists since Python 3.3, pypy does not + # implement it even as of 3.9. + gc.callbacks.append(self._clean_thread_parent_frames) + + def init_shell(self): # Initialize the InteractiveShell subclass self.shell = self.shell_class.instance( parent=self, @@ -145,20 +162,6 @@ def __init__(self, **kwargs): for msg_type in comm_msg_types: self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type) - if _use_appnope() and self._darwin_app_nap: - # Disable app-nap as the kernel is not a gui but can have guis - import appnope # type:ignore[import-untyped] - - appnope.nope() - - self._new_threads_parent_header = {} - self._initialize_thread_hooks() - - if hasattr(gc, "callbacks"): - # while `gc.callbacks` exists since Python 3.3, pypy does not - # implement it even as of 3.9. 
- gc.callbacks.append(self._clean_thread_parent_frames) - help_links = List( [ { diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 98b08b845..d04fd915c 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -726,8 +726,99 @@ def start(self) -> None: if self.poller is not None: self.poller.start() backend = "trio" if self.trio_loop else "asyncio" - run(self.main, backend=backend) - return + + while True: + run(self.main, backend=backend) + if not getattr(self.kernel, "_fork_requested", False): + break + self.fork() + + def fork(self): + # HACK: Why is this necessary? + # Without it, the *parent* kernel doesn't work. + # Also, it doesn't work if I try to start it again with + # self.init_iopub()... + self.iopub_thread.stop() + + # Create a temporary connection file that will be inherited by the child process. + connection_file, conn = write_connection_file() + + parent_pid = os.getpid() + pid = os.fork() + self.kernel._fork_requested = False # reset for parent AND child + if pid == 0: + self.log.debug("Child kernel with pid %s", os.getpid()) + + # close all sockets and ioloops + self.close() + + # Reset all ports so they will be reinitialized with the ports from the connection file + for name in [ + "%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control") + ]: + setattr(self, name, 0) + self.connection_file = connection_file + + # Reset the ZMQ context for it to be recreated + self.context = None + + # Make ParentPoller work correctly (the new process is a child of the previous kernel) + self.parent_handle = parent_pid + + # Session has a protection against sending messages from forked processes, through the `check_pid` flag. 
+ self.session.pid = os.getpid() + self.session.key = conn["key"].encode() + + self.init_connection_file() + self.init_poller() + self.init_sockets() + self.init_heartbeat() + self.init_io() + + kernel = self.kernel + params = dict( + parent=self, + session=self.session, + control_socket=self.control_socket, + control_thread=self.control_thread, + debugpy_socket=self.debugpy_socket, + debug_shell_socket=self.debug_shell_socket, + shell_socket=self.shell_socket, + iopub_thread=self.iopub_thread, + iopub_socket=self.iopub_socket, + stdin_socket=self.stdin_socket, + log=self.log, + profile_dir=self.profile_dir, + ) + for k, v in params.items(): + setattr(kernel, k, v) + + kernel.user_ns = kernel.shell.user_ns + kernel.init_shell() + + kernel.record_ports({name + "_port": port for name, port in self._ports.items()}) + self.kernel = kernel + + # Allow the displayhook to get the execution count + self.displayhook.get_execution_count = lambda: kernel.execution_count + + # shell init steps + self.init_shell() + if self.shell: + self.init_gui_pylab() + self.init_extensions() + self.init_code() + # flush stdout/stderr, so that anything written to these streams during + # initialization does not get associated with the first execution request + sys.stdout.flush() + sys.stderr.flush() + self.start() + else: + self.log.debug("Parent kernel will resume") + # keep a reference, since we will set this to None + post_fork_callback = self.kernel._post_fork_callback + post_fork_callback(pid, conn) + self.kernel._post_fork_callback = None async def main(self): async with create_task_group() as tg: diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index e507964b2..1ff57ca3e 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -217,6 +217,7 @@ def _parent_header(self): "shutdown_request", "is_complete_request", "interrupt_request", + "fork", # deprecated: "apply_request", ] @@ -229,6 +230,25 @@ def _parent_header(self): "usage_request", ] + def fork(self, 
stream, ident, parent): + # Forking in the (async)io loop is not supported. + # Instead, we stop the kernel and record the request on it + # (`_fork_requested`, `_post_fork_callback`) for the caller to act on. + # NOTE(review): presumably whoever restarts the loop performs the fork -- confirm. + self._fork_requested = True + + def post_fork_callback(pid, conn): + reply_content = json_clean({"status": "ok", "pid": pid, "conn": conn}) + metadata = {} + metadata = self.finish_metadata(parent, metadata, reply_content) + + self.session.send( + stream, "fork_reply", reply_content, parent, metadata=metadata, ident=ident + ) + + self._post_fork_callback = post_fork_callback + self.stop() + def __init__(self, **kwargs): """Initialize the kernel.""" super().__init__(**kwargs) @@ -469,7 +489,8 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: if not self._is_test and self.control_socket is not None: if self.control_thread: self.control_thread.set_task(self.control_main) - self.control_thread.start() + if not self.control_thread.is_alive(): + self.control_thread.start() else: tg.start_soon(self.control_main) diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 88f02ae9a..2c20ba934 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -23,6 +23,7 @@ from .utils import ( TIMEOUT, assemble_output, + connect_to_kernel, execute, flush_channels, get_reply, @@ -491,6 +492,50 @@ def test_shutdown(): assert not km.is_alive() +def test_fork_metadata(): + with new_kernel() as kc: + from .test_message_spec import validate_message + + km = kc.parent + fork_msg_id = kc.fork() + fork_reply = kc.get_shell_msg(timeout=TIMEOUT) + validate_message(fork_reply, "fork_reply", fork_msg_id) + assert fork_msg_id == fork_reply["parent_header"]["msg_id"] == fork_msg_id + assert fork_reply["content"]["conn"]["key"] != kc.session.key.decode() + fork_pid = fork_reply["content"]["pid"] + _check_status(fork_reply["content"]) + wait_for_idle(kc) + + assert fork_pid != km.provisioner.pid + # TODO: Inspect if `fork_pid` is running? 
Might need to use `psutil` for this in order to be cross platform + + with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork: + assert fork_reply["content"]["conn"]["key"] == kc_fork.session.key.decode() + kc_fork.shutdown() + + +def test_fork(): + def execute_with_user_expression(kc, code, user_expression): + _, reply = execute(code, kc=kc, user_expressions={"my-user-expression": user_expression}) + content = reply["user_expressions"]["my-user-expression"]["data"]["text/plain"] + wait_for_idle(kc) + return content + + """Kernel forks after fork_request""" + with kernel() as kc: + assert execute_with_user_expression(kc, "a = 1", "a") == "1" + assert execute_with_user_expression(kc, "b = 2", "b") == "2" + kc.fork() + fork_reply = kc.get_shell_msg(timeout=TIMEOUT) + wait_for_idle(kc) + + with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork: + assert execute_with_user_expression(kc_fork, "a = 11", "a, b") == str((11, 2)) + assert execute_with_user_expression(kc_fork, "b = 12", "a, b") == str((11, 12)) + assert execute_with_user_expression(kc, "z = 20", "a, b") == str((1, 2)) + kc_fork.shutdown() + + def test_interrupt_during_input(): """ The kernel exits after being interrupted while waiting in input(). 
diff --git a/tests/test_message_spec.py b/tests/test_message_spec.py
index d98503ee7..c5ecbde03 100644
--- a/tests/test_message_spec.py
+++ b/tests/test_message_spec.py
@@ -208,6 +208,12 @@ class IsCompleteReplyIncomplete(Reference):
     indent = Unicode()
 
 
+# Content of a ``fork_reply``: the child's pid and its connection info.
+class ForkReply(Reply):
+    pid = Integer()
+    conn = Dict()
+
+
 # IOPub messages
 
 
@@ -255,6 +261,7 @@ class HistoryReply(Reply):
     "stream": Stream(),
     "display_data": DisplayData(),
     "header": RHeader(),
+    "fork_reply": ForkReply(),
 }
 
 # -----------------------------------------------------------------------------
diff --git a/tests/utils.py b/tests/utils.py
index b1b4119f0..a7b37871c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -212,3 +212,24 @@ def __enter__(self):
     def __exit__(self, exc, value, tb):
         os.chdir(self.old_wd)
         return super().__exit__(exc, value, tb)
+
+
+@contextmanager
+def connect_to_kernel(connection_info, timeout):
+    """Yield a blocking client connected to an already-running kernel.
+
+    ``connection_info`` is handed to ``load_connection_info`` and ``timeout``
+    to ``wait_for_ready``. The client's channels are stopped on exit, even if
+    the readiness wait or the ``with`` body raises.
+    """
+    from jupyter_client import BlockingKernelClient
+
+    kc = BlockingKernelClient()
+    kc.log.setLevel("DEBUG")
+    kc.load_connection_info(connection_info)
+    kc.start_channels()
+    try:
+        kc.wait_for_ready(timeout)
+        yield kc
+    finally:
+        kc.stop_channels()