Skip to content

Commit

Permalink
refactor: split PyBoxManager and AsyncPyBoxManager
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardzjl committed Dec 17, 2024
1 parent 942c19d commit 9416ff8
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 130 deletions.
7 changes: 5 additions & 2 deletions src/pybox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

from pybox.local import AsyncLocalPyBox, LocalPyBox, LocalPyBoxManager
from pybox.remote import RemotePyBox, RemotePyBoxManager
from pybox.local import AsyncLocalPyBox, AsyncLocalPyBoxManager, LocalPyBox, LocalPyBoxManager
from pybox.remote import AsyncRemotePyBox, AsyncRemotePyBoxManager, RemotePyBox, RemotePyBoxManager
from pybox.schema import PyBoxOut

__all__ = [
"AsyncLocalPyBox",
"AsyncLocalPyBoxManager",
"AsyncRemotePyBox",
"AsyncRemotePyBoxManager",
"LocalPyBox",
"LocalPyBoxManager",
"RemotePyBox",
Expand Down
23 changes: 0 additions & 23 deletions src/pybox/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,6 @@ def start(
Kernel: An iPython kernel that executes code.
"""

async def astart(
self,
kernel_id: str | None = None,
**kwargs,
) -> BasePyBox:
"""Retrieve an existing kernel or create a new one.
Args:
kernel_id (str): kernel id.
Returns:
Kernel: An iPython kernel that executes code.
"""
return self.start(kernel_id=kernel_id, **kwargs)

@abstractmethod
def shutdown(
self,
Expand All @@ -69,11 +54,3 @@ def shutdown(
) -> None:
"""Shutdown the kernel."""
...

async def ashutdown(
self,
kernel_id: str,
**kwargs,
) -> None:
"""Shutdown the kernel."""
self.shutdown(kernel_id=kernel_id, **kwargs)
22 changes: 12 additions & 10 deletions src/pybox/kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,17 @@ def start(self, kernel_id: str | None = None, cwd: str | None = None, **kwargs)

return LocalPyBox(kernel_id=kernel_id, client=kernel_client)

async def astart(self, kernel_id: str | None = None, cwd: str | None = None, **kwargs) -> AsyncLocalPyBox:
def shutdown(self, kernel_id: str, **kwargs) -> None:
"""Shutdown the kernel in kubernetes.
Args:
kernel_id (str): kernel_id
"""
self.client.delete_by_kernel_id(kernel_id, **kwargs)


class AsyncKubePyBoxManager(KubePyBoxManager):
async def start(self, kernel_id: str | None = None, cwd: str | None = None, **kwargs) -> AsyncLocalPyBox:
"""Retrieve an existing kernel or create a new one in kubernetes
Args:
Expand Down Expand Up @@ -87,15 +97,7 @@ async def astart(self, kernel_id: str | None = None, cwd: str | None = None, **k

return AsyncLocalPyBox(kernel_id=kernel_id, client=kernel_client)

def shutdown(self, kernel_id: str, **kwargs) -> None:
"""Shutdown the kernel in kubernetes.
Args:
kernel_id (str): kernel_id
"""
self.client.delete_by_kernel_id(kernel_id, **kwargs)

async def ashutdown(self, kernel_id: str, **kwargs) -> None:
async def shutdown(self, kernel_id: str, **kwargs) -> None:
"""Shutdown the kubernetes kernel by kernel id.
Args:
Expand Down
97 changes: 45 additions & 52 deletions src/pybox/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from jupyter_client import AsyncKernelManager, AsyncMultiKernelManager, KernelManager, MultiKernelManager
from jupyter_client.multikernelmanager import DuplicateKernelError
from jupyter_core.utils import run_sync

if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -297,32 +296,13 @@ class LocalPyBoxManager(BasePyBoxManager):
def __init__(
self,
kernel_manager: MultiKernelManager | None = None,
async_kernel_manager: AsyncMultiKernelManager | None = None,
profile_dir: str | None = None,
):
self.profile_dir = profile_dir
if kernel_manager is None:
self.kernel_manager = MultiKernelManager()
else:
self.kernel_manager = kernel_manager
if async_kernel_manager is None:
self.async_kernel_manager = AsyncMultiKernelManager()
else:
self.async_kernel_manager = async_kernel_manager

# TODO: It works well in async scenarios but blocks in sync scenarios, so it should be disabled now.
# weakref.finalize(self, self.cleanup)

async def acleanup(self, *, now: bool = False):
"""clean up all the kernels."""
logger.info("Shutting down all sync kernels")
self.shutdown_all(now=now)

# close the async kernels
logger.info("Shutting down all async kernels")
await self.ashutdown_all(now=now)

cleanup = run_sync(acleanup)

def __enter__(self):
return self
Expand All @@ -334,17 +314,6 @@ def __exit__(self, exc_type, exc_value, traceback):

return True

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.acleanup(now=True)

if exc_type is not None:
return False

return True

def start(
self,
kernel_id: str | None = None,
Expand All @@ -364,7 +333,49 @@ def start(
km = self.kernel_manager.get_kernel(kernel_id=kid)
return LocalPyBox(km=km, mkm=self.kernel_manager)

async def astart(
def shutdown(
self,
kernel_id: str,
*,
now: bool = False,
restart: bool = False,
) -> None:
try:
self.kernel_manager.shutdown_kernel(kernel_id=kernel_id, now=now, restart=restart)
except KeyError:
logger.warning("kernel %s not found", kernel_id)
else:
logger.info("Kernel %s shut down", kernel_id)

def shutdown_all(self, *args, **kwargs):
if len(self.kernel_manager):
self.kernel_manager.shutdown_all(*args, **kwargs)


class AsyncLocalPyBoxManager(LocalPyBoxManager):
def __init__(
self,
async_kernel_manager: AsyncMultiKernelManager | None = None,
profile_dir: str | None = None,
):
self.profile_dir = profile_dir
if async_kernel_manager is None:
self.async_kernel_manager = AsyncMultiKernelManager()
else:
self.async_kernel_manager = async_kernel_manager

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.acleanup(now=True)

if exc_type is not None:
return False

return True

async def start(
self,
kernel_id: str | None = None,
**kwargs,
Expand All @@ -384,21 +395,7 @@ async def astart(
km = self.async_kernel_manager.get_kernel(kernel_id=kid)
return AsyncLocalPyBox(km=km, mkm=self.async_kernel_manager)

def shutdown(
self,
kernel_id: str,
*,
now: bool = False,
restart: bool = False,
) -> None:
try:
self.kernel_manager.shutdown_kernel(kernel_id=kernel_id, now=now, restart=restart)
except KeyError:
logger.warning("kernel %s not found", kernel_id)
else:
logger.info("Kernel %s shut down", kernel_id)

async def ashutdown(
async def shutdown(
self,
kernel_id: str,
*,
Expand All @@ -412,10 +409,6 @@ async def ashutdown(
else:
logger.info("Kernel %s shut down", kernel_id)

def shutdown_all(self, *args, **kwargs):
if len(self.kernel_manager):
self.kernel_manager.shutdown_all(*args, **kwargs)

async def ashutdown_all(self, *args, **kwargs):
async def shutdown_all(self, *args, **kwargs):
if len(self.async_kernel_manager):
await self.async_kernel_manager.shutdown_all(*args, **kwargs)
44 changes: 23 additions & 21 deletions src/pybox/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,28 @@ def start(
logger.info("Started kernel with id %s", kernel.id)
return box

async def astart(
def shutdown(
self,
kernel_id: str,
) -> None:
response = requests.delete(urljoin(str(self.host), f"/api/kernels/{kernel_id}"), timeout=60)
if not response.ok:
if response.status_code == requests.codes.not_found:
logger.warning("kernel %s not found", kernel_id)
else:
err_msg = f"Error deleting kernel {kernel_id}: {response.status_code}\n{response.content}"
raise RuntimeError(err_msg)
logger.info("Kernel %s shut down", kernel_id)

def get_ws_url(self, kernel_id: str) -> str:
base = urlparse(self.host)
ws_scheme = "wss" if base.scheme == "https" else "ws"
ws_base = urlunparse(base._replace(scheme=ws_scheme))
return urljoin(ws_base, f"/api/kernels/{kernel_id}/channels")


class AsyncRemotePyBoxManager(RemotePyBoxManager):
async def start(
self,
kernel_id: str | None = None,
cwd: str | None = None,
Expand Down Expand Up @@ -281,20 +302,7 @@ async def astart(
logger.info("Started kernel with id %s", kernel.id)
return box

def shutdown(
self,
kernel_id: str,
) -> None:
response = requests.delete(urljoin(str(self.host), f"/api/kernels/{kernel_id}"), timeout=60)
if not response.ok:
if response.status_code == requests.codes.not_found:
logger.warning("kernel %s not found", kernel_id)
else:
err_msg = f"Error deleting kernel {kernel_id}: {response.status_code}\n{response.content}"
raise RuntimeError(err_msg)
logger.info("Kernel %s shut down", kernel_id)

async def ashutdown(
async def shutdown(
self,
kernel_id: str,
) -> None:
Expand All @@ -308,9 +316,3 @@ async def ashutdown(
err_msg = f"Error deleting kernel {kernel_id}: {response.status}\n{response.content}"
raise RuntimeError(err_msg)
logger.info("Kernel %s shut down", kernel_id)

def get_ws_url(self, kernel_id: str) -> str:
base = urlparse(self.host)
ws_scheme = "wss" if base.scheme == "https" else "ws"
ws_base = urlunparse(base._replace(scheme=ws_scheme))
return urljoin(ws_base, f"/api/kernels/{kernel_id}/channels")
Loading

0 comments on commit 9416ff8

Please sign in to comment.