Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for async kernel management #794

Merged
merged 9 commits into the base branch from the source branch
Jul 10, 2020
3 changes: 2 additions & 1 deletion enterprise_gateway/enterprisegatewayapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Distributed under the terms of the Modified BSD License.
"""Enterprise Gateway Jupyter application."""

import asyncio
import errno
import getpass
import logging
Expand Down Expand Up @@ -299,7 +300,7 @@ def shutdown(self):
"""Shuts down all running kernels."""
kids = self.kernel_manager.list_kernel_ids()
for kid in kids:
self.kernel_manager.shutdown_kernel(kid, now=True)
asyncio.get_event_loop().run_until_complete(self.kernel_manager.shutdown_kernel(kid, now=True))

def stop(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions enterprise_gateway/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,12 @@ def dynamic_config_interval_changed(self, event):
)

kernel_manager_class = Type(
klass="notebook.services.kernels.kernelmanager.MappingKernelManager",
klass="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
default_value="enterprise_gateway.services.kernels.remotemanager.RemoteMappingKernelManager",
config=True,
help="""
The kernel manager class to use. Must be a subclass
of `notebook.services.kernels.MappingKernelManager`.
of `enterprise_gateway.services.kernels.RemoteMappingKernelManager`.
"""
)

Expand Down
41 changes: 20 additions & 21 deletions enterprise_gateway/services/kernels/remotemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
import uuid
import zmq

from tornado import gen, web
from tornado import web
from ipython_genutils.py3compat import unicode_type
from ipython_genutils.importstring import import_item
from notebook.services.kernels.kernelmanager import MappingKernelManager
from notebook.utils import maybe_future
from jupyter_client.ioloop.manager import IOLoopKernelManager
from notebook.services.kernels.kernelmanager import AsyncMappingKernelManager
from jupyter_client.ioloop.manager import AsyncIOLoopKernelManager
from traitlets import directional_link, default, Bool, log as traitlets_log

from ..processproxies.processproxy import LocalProcessProxy, RemoteProcessProxy
Expand Down Expand Up @@ -88,9 +87,8 @@ def new_kernel_id(**kwargs):
return kernel_id


class RemoteMappingKernelManager(MappingKernelManager):
"""Extends the MappingKernelManager with support for managing remote kernels via the process-proxy. """

class RemoteMappingKernelManager(AsyncMappingKernelManager):
"""Extends the AsyncMappingKernelManager with support for managing remote kernels via the process-proxy. """
def _kernel_manager_class_default(self):
return 'enterprise_gateway.services.kernels.remotemanager.RemoteKernelManager'

Expand All @@ -105,8 +103,7 @@ def _refresh_kernel(self, kernel_id):
self.parent.kernel_session_manager.load_session(kernel_id)
return self.parent.kernel_session_manager.start_session(kernel_id)

@gen.coroutine
def start_kernel(self, *args, **kwargs):
async def start_kernel(self, *args, **kwargs):
"""Starts a kernel for a session and return its kernel_id.

Returns
Expand All @@ -118,9 +115,9 @@ def start_kernel(self, *args, **kwargs):
username = KernelSessionManager.get_kernel_username(**kwargs)
self.log.debug("RemoteMappingKernelManager.start_kernel: {kernel_name}, kernel_username: {username}".
format(kernel_name=kwargs['kernel_name'], username=username))
kernel_id = yield maybe_future(super(RemoteMappingKernelManager, self).start_kernel(*args, **kwargs))
kernel_id = await super(RemoteMappingKernelManager, self).start_kernel(*args, **kwargs)
self.parent.kernel_session_manager.create_session(kernel_id, **kwargs)
raise gen.Return(kernel_id)
return kernel_id

def remove_kernel(self, kernel_id):
""" Removes the kernel associated with `kernel_id` from the internal map and deletes the kernel session. """
Expand Down Expand Up @@ -210,8 +207,8 @@ def new_kernel_id(self, **kwargs):
return new_kernel_id(kernel_id_fn=super(RemoteMappingKernelManager, self).new_kernel_id, log=self.log)


class RemoteKernelManager(EnterpriseGatewayConfigMixin, IOLoopKernelManager):
"""Extends the IOLoopKernelManager used by the MappingKernelManager.
class RemoteKernelManager(EnterpriseGatewayConfigMixin, AsyncIOLoopKernelManager):
"""Extends the AsyncIOLoopKernelManager used by the RemoteMappingKernelManager.

This class is responsible for detecting that a remote kernel is desired, then launching the
appropriate class (previously pulled from the kernel spec). The process 'proxy' is
Expand Down Expand Up @@ -289,7 +286,7 @@ def _link_dependent_props(self):
]
self._links = [directional_link((eg_instance, prop), (self, prop)) for prop in dependent_props]

def start_kernel(self, **kwargs):
async def start_kernel(self, **kwargs):
"""Starts a kernel in a separate process.

Where the started kernel resides depends on the configured process proxy.
Expand All @@ -302,7 +299,7 @@ def start_kernel(self, **kwargs):
"""
self._get_process_proxy()
self._capture_user_overrides(**kwargs)
super(RemoteKernelManager, self).start_kernel(**kwargs)
await super(RemoteKernelManager, self).start_kernel(**kwargs)

def _capture_user_overrides(self, **kwargs):
"""
Expand Down Expand Up @@ -338,7 +335,7 @@ def from_ns(match):
return [pat.sub(from_ns, arg) for arg in cmd]
return cmd

def _launch_kernel(self, kernel_cmd, **kwargs):
async def _launch_kernel(self, kernel_cmd, **kwargs):
# Note: despite the under-bar prefix to this method, the jupyter_client comment says that
# this method should be "[overridden] in a subclass to launch kernel subprocesses differently".
# So that's what we've done.
Expand All @@ -357,7 +354,8 @@ def _launch_kernel(self, kernel_cmd, **kwargs):
del env['KG_AUTH_TOKEN']

self.log.debug("Launching kernel: {} with command: {}".format(self.kernel_spec.display_name, kernel_cmd))
return self.process_proxy.launch_process(kernel_cmd, **kwargs)
proxy = await self.process_proxy.launch_process(kernel_cmd, **kwargs)
return proxy

def request_shutdown(self, restart=False):
""" Send a shutdown request via control channel and process proxy (if remote). """
Expand All @@ -368,7 +366,7 @@ def request_shutdown(self, restart=False):
if isinstance(self.process_proxy, RemoteProcessProxy):
self.process_proxy.shutdown_listener()

def restart_kernel(self, now=False, **kwargs):
async def restart_kernel(self, now=False, **kwargs):
"""Restarts a kernel with the arguments that were used to launch it.

This is an automatic restart request (now=True) AND this is associated with a
Expand Down Expand Up @@ -401,7 +399,7 @@ def restart_kernel(self, now=False, **kwargs):
# Use the parent mapping kernel manager so activity monitoring and culling is also shutdown
self.mapping_kernel_manager.shutdown_kernel(kernel_id, now=now)
return
super(RemoteKernelManager, self).restart_kernel(now, **kwargs)
await super(RemoteKernelManager, self).restart_kernel(now, **kwargs)
if isinstance(self.process_proxy, RemoteProcessProxy): # for remote kernels...
# Re-establish activity watching...
if self._activity_stream:
Expand All @@ -414,7 +412,7 @@ def restart_kernel(self, now=False, **kwargs):
self.kernel_session_manager.refresh_session(kernel_id)
self.restarting = False

def signal_kernel(self, signum):
async def signal_kernel(self, signum):
"""Sends signal `signum` to the kernel process. """
if self.has_kernel:
if signum == signal.SIGINT:
Expand Down Expand Up @@ -456,7 +454,8 @@ def cleanup(self, connection_file=True):
if self.process_proxy:
self.process_proxy.cleanup()
self.process_proxy = None
return super(RemoteKernelManager, self).cleanup(connection_file)

super(RemoteKernelManager, self).cleanup(connection_file)

def write_connection_file(self):
"""Write connection info to JSON dict in self.connection_file if the kernel is local.
Expand Down
31 changes: 15 additions & 16 deletions enterprise_gateway/services/processproxies/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
# Distributed under the terms of the Modified BSD License.
"""Code related to managing kernels running in Conductor clusters."""

import asyncio
import os
import signal
import json
import time
import subprocess
import socket
import re
import signal
import socket
import subprocess

from jupyter_client import launch_kernel, localinterfaces

Expand All @@ -34,9 +34,9 @@ def __init__(self, kernel_manager, proxy_config):
self.conductor_endpoint = proxy_config.get('conductor_endpoint',
kernel_manager.conductor_endpoint)

def launch_process(self, kernel_cmd, **kwargs):
async def launch_process(self, kernel_cmd, **kwargs):
"""Launches the specified process within a Conductor cluster environment."""
super(ConductorClusterProcessProxy, self).launch_process(kernel_cmd, **kwargs)
await super(ConductorClusterProcessProxy, self).launch_process(kernel_cmd, **kwargs)
# Get cred from process env
env_dict = dict(os.environ.copy())
if env_dict and 'EGO_SERVICE_CREDENTIAL' in env_dict:
Expand All @@ -55,8 +55,7 @@ def launch_process(self, kernel_cmd, **kwargs):
self.env = kwargs.get('env')
self.log.debug("Conductor cluster kernel launched using Conductor endpoint: {}, pid: {}, Kernel ID: {}, "
"cmd: '{}'".format(self.conductor_endpoint, self.local_proc.pid, self.kernel_id, kernel_cmd))
self.confirm_remote_startup()

await self.confirm_remote_startup()
return self

def _update_launch_info(self, kernel_cmd, **kwargs):
Expand Down Expand Up @@ -114,7 +113,7 @@ def send_signal(self, signum):
else:
return super(ConductorClusterProcessProxy, self).send_signal(signum)

def kill(self):
async def kill(self):
"""Kill a kernel.
:return: None if the application existed and is not in RUNNING state, False otherwise.
"""
Expand All @@ -127,7 +126,7 @@ def kill(self):
i = 1
state = self._query_app_state_by_driver_id(self.driver_id)
while state not in ConductorClusterProcessProxy.final_states and i <= max_poll_attempts:
time.sleep(poll_interval)
await asyncio.sleep(poll_interval)
state = self._query_app_state_by_driver_id(self.driver_id)
i = i + 1

Expand Down Expand Up @@ -173,7 +172,7 @@ def _parse_driver_submission_id(self, submission_response):
self.driver_id = driver_id[0]
self.log.debug("Driver ID: {}".format(driver_id[0]))

def confirm_remote_startup(self):
async def confirm_remote_startup(self):
""" Confirms the application is in a started state before returning. Should post-RUNNING states be
unexpectedly encountered ('FINISHED', 'KILLED', 'RECLAIMED') then we must throw, otherwise the rest
of the gateway will believe its talking to a valid kernel.
Expand All @@ -187,7 +186,7 @@ def confirm_remote_startup(self):
output = self.local_proc.stderr.read().decode("utf-8")
self._parse_driver_submission_id(output)
i += 1
self.handle_timeout()
await self.handle_timeout()

if self._get_application_id(True):
# Once we have an application ID, start monitoring state, obtain assigned host and get connection info
Expand All @@ -203,7 +202,7 @@ def confirm_remote_startup(self):
format(i, app_state, self.assigned_host, self.kernel_id, self.application_id))

if self.assigned_host != '':
ready_to_connect = self.receive_connection_info()
ready_to_connect = await self.receive_connection_info()
else:
self.detect_launch_failure()

Expand All @@ -223,9 +222,9 @@ def _get_application_state(self):
self.assigned_ip = socket.gethostbyname(self.assigned_host)
return app_state

def handle_timeout(self):
async def handle_timeout(self):
"""Checks to see if the kernel launch timeout has been exceeded while awaiting connection info."""
time.sleep(poll_interval)
await asyncio.sleep(poll_interval)
time_interval = RemoteProcessProxy.get_time_diff(self.start_time, RemoteProcessProxy.get_current_time())

if time_interval > self.kernel_launch_timeout:
Expand All @@ -240,7 +239,7 @@ def handle_timeout(self):
else:
reason = "App {} is WAITING, but waited too long ({} secs) to get connection file". \
format(self.application_id, self.kernel_launch_timeout)
self.kill()
await self.kill()
timeout_message = "KernelID: '{}' launch timeout due to: {}".format(self.kernel_id, reason)
self.log_and_raise(http_status_code=error_http_code, reason=timeout_message)

Expand Down
16 changes: 7 additions & 9 deletions enterprise_gateway/services/processproxies/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Distributed under the terms of the Modified BSD License.
"""Code related to managing kernels running in containers."""

import abc
import os
import signal
import abc

import urllib3 # docker ends up using this and it causes lots of noise, so turn off warnings

from jupyter_client import launch_kernel, localinterfaces
Expand Down Expand Up @@ -51,7 +50,7 @@ def _determine_kernel_images(self, proxy_config):
self.kernel_executor_image = proxy_config.get('executor_image_name')
self.kernel_executor_image = os.environ.get('KERNEL_EXECUTOR_IMAGE', self.kernel_executor_image)

def launch_process(self, kernel_cmd, **kwargs):
async def launch_process(self, kernel_cmd, **kwargs):
"""Launches the specified process within the container environment."""
# Set env before superclass call so we see these in the debug output

Expand All @@ -64,7 +63,7 @@ def launch_process(self, kernel_cmd, **kwargs):

self._enforce_uid_gid_blacklists(**kwargs)

super(ContainerProcessProxy, self).launch_process(kernel_cmd, **kwargs)
await super(ContainerProcessProxy, self).launch_process(kernel_cmd, **kwargs)

self.local_proc = launch_kernel(kernel_cmd, **kwargs)
self.pid = self.local_proc.pid
Expand All @@ -73,8 +72,7 @@ def launch_process(self, kernel_cmd, **kwargs):
self.log.info("{}: kernel launched. Kernel image: {}, KernelID: {}, cmd: '{}'"
.format(self.__class__.__name__, self.kernel_image, self.kernel_id, kernel_cmd))

self.confirm_remote_startup()

await self.confirm_remote_startup()
return self

def _enforce_uid_gid_blacklists(self, **kwargs):
Expand Down Expand Up @@ -153,19 +151,19 @@ def cleanup(self):
self.kill()
super(ContainerProcessProxy, self).cleanup()

def confirm_remote_startup(self):
async def confirm_remote_startup(self):
"""Confirms the container has started and returned necessary connection information."""
self.start_time = RemoteProcessProxy.get_current_time()
i = 0
ready_to_connect = False # we're ready to connect when we have a connection file to use
while not ready_to_connect:
i += 1
self.handle_timeout()
await self.handle_timeout()

container_status = self.get_container_status(str(i))
if container_status:
if self.assigned_host != '':
ready_to_connect = self.receive_connection_info()
ready_to_connect = await self.receive_connection_info()
self.pid = 0 # We won't send process signals for kubernetes lifecycle management
self.pgid = 0
else:
Expand Down
Loading