Skip to content

Commit

Permalink
Add support for external kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 25, 2023
1 parent 65af765 commit e3ea8d3
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 43 deletions.
3 changes: 3 additions & 0 deletions plugins/kernels/fps_kernels/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional

from fps.config import PluginModel, get_config # type: ignore
from fps.hooks import register_config # type: ignore


class KernelConfig(PluginModel):
default_kernel: str = "python3"
connection_path: Optional[str] = None


def get_kernel_config():
Expand Down
53 changes: 26 additions & 27 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
from fastapi import WebSocket, WebSocketDisconnect # type: ignore
from starlette.websockets import WebSocketState

from ..kernel_driver.connect import (
cfg_t,
connect_channel,
launch_kernel,
read_connection_file,
)
from ..kernel_driver.connect import cfg_t, connect_channel
from ..kernel_driver.connect import launch_kernel as _launch_kernel
from ..kernel_driver.connect import read_connection_file
from ..kernel_driver.connect import (
write_connection_file as _write_connection_file, # type: ignore
)
Expand Down Expand Up @@ -108,19 +105,20 @@ def allow_messages(self, message_types: Optional[Iterable[str]] = None):
def connections(self) -> int:
return len(self.sessions)

async def start(self) -> None:
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
async def start(self, launch_kernel: bool = True) -> None:
self.last_activity = {
"date": datetime.utcnow().isoformat() + "Z",
"execution_state": "starting",
}
self.kernel_process = await launch_kernel(
self.kernelspec_path,
self.connection_file_path,
self.kernel_cwd,
self.capture_kernel_output,
)
if launch_kernel:
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
self.kernel_process = await _launch_kernel(
self.kernelspec_path,
self.connection_file_path,
self.kernel_cwd,
self.capture_kernel_output,
)
assert self.connection_cfg is not None
identity = uuid.uuid4().hex.encode("ascii")
self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity)
Expand All @@ -136,17 +134,18 @@ async def start(self) -> None:
]

async def stop(self) -> None:
# FIXME: stop kernel in a better way
try:
self.kernel_process.send_signal(signal.SIGINT)
self.kernel_process.kill()
await self.kernel_process.wait()
except Exception:
pass
try:
os.remove(self.connection_file_path)
except Exception:
pass
if self.write_connection_file:
# FIXME: stop kernel in a better way
try:
self.kernel_process.send_signal(signal.SIGINT)
self.kernel_process.kill()
await self.kernel_process.wait()
except BaseException:
pass
try:
os.remove(self.connection_file_path)
except BaseException:
pass
for task in self.channel_tasks:
task.cancel()
self.channel_tasks = []
Expand Down Expand Up @@ -266,7 +265,7 @@ async def send_to_ws(self, websocket, parts, parent_header, channel_name):
bin_msg = serialize_msg_to_ws_v1(parts, channel_name)
try:
await websocket.websocket.send_bytes(bin_msg)
except Exception:
except BaseException:
pass
# FIXME: update last_activity
# but we don't want to parse the content!
Expand Down
9 changes: 6 additions & 3 deletions plugins/kernels/fps_kernels/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Optional

from pydantic import BaseModel


class KernelName(BaseModel):
name: str
class KernelInfo(BaseModel):
name: Optional[str] = None
id: Optional[str] = None


class CreateSession(BaseModel):
kernel: KernelName
kernel: KernelInfo
name: str
path: str
type: str
Expand Down
115 changes: 104 additions & 11 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import json
import sys
import uuid
from http import HTTPStatus
from pathlib import Path
from typing import Set, Tuple

from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import FileResponse
Expand All @@ -11,6 +13,7 @@
from fps_frontend.config import get_frontend_config # type: ignore
from fps_yjs.routes import YDocWebSocketHandler # type: ignore
from starlette.requests import Request # type: ignore
from watchfiles import Change, awatch

from .config import get_kernel_config
from .kernel_driver.driver import KernelDriver # type: ignore
Expand All @@ -25,10 +28,67 @@
router = APIRouter()

kernelspecs: dict = {}
kernel_id_to_connection_file = {}
sessions: dict = {}
prefix_dir: Path = Path(sys.prefix)


async def process_connection_files(changes: Set[Tuple[Change, str]]):
# get rid of "simultaneously" added/deleted files
file_changes = {}
for c in changes:
change, path = c
if path not in file_changes:
file_changes[path] = []
file_changes[path].append(change)
to_delete = []
for p, c in file_changes.items():
if Change.added in c and Change.deleted in c:
c.remove(Change.added)
c.remove(Change.deleted)
if not c:
to_delete.append(p)
for path in to_delete:
del file_changes[path]
# process file changes
for path, c in file_changes.items():
for change in c:
if change == Change.deleted:
if path in kernels:
kernel_id = list(kernel_id_to_connection_file.keys())[
list(kernel_id_to_connection_file.values()).index(path)
]
del kernels[kernel_id]
elif change == Change.added:
try:
data = json.loads(Path(path).read_text())
except BaseException:
continue
if "kernel_name" not in data or "key" not in data:
continue
# looks like a kernel connection file
kernel_id = str(uuid.uuid4())
kernel_id_to_connection_file[kernel_id] = path
kernels[kernel_id] = {"name": data["kernel_name"], "server": None, "driver": None}


async def watch_connection_files(path: Path):
# first time scan, treat everything as added files
initial_changes = {(Change.added, str(p)) for p in path.iterdir()}
await process_connection_files(initial_changes)
# then, on every change
async for changes in awatch(path):
await process_connection_files(changes)


@router.on_event("startup")
async def startup():
kernel_config = get_kernel_config()
if kernel_config.connection_path is not None:
path = Path(kernel_config.connection_path)
asyncio.create_task(watch_connection_files(path))


@router.on_event("shutdown")
async def stop_kernels():
for kernel in kernels.values():
Expand Down Expand Up @@ -74,13 +134,21 @@ async def get_kernels(
):
results = []
for kernel_id, kernel in kernels.items():
if kernel["server"]:
connections = kernel["server"].connections
last_activity = kernel["server"].last_activity["date"]
execution_state = kernel["server"].last_activity["execution_state"]
else:
connections = 0
last_activity = ""
execution_state = "idle"
results.append(
{
"id": kernel_id,
"name": kernel["name"],
"connections": kernel["server"].connections,
"last_activity": kernel["server"].last_activity["date"],
"execution_state": kernel["server"].last_activity["execution_state"],
"connections": connections,
"last_activity": last_activity,
"execution_state": execution_state,
}
)
return results
Expand All @@ -95,6 +163,8 @@ async def delete_session(
kernel_server = kernels[kernel_id]["server"]
await kernel_server.stop()
del kernels[kernel_id]
if kernel_id in kernel_id_to_connection_file:
del kernel_id_to_connection_file[kernel_id]
del sessions[session_id]
return Response(status_code=HTTPStatus.NO_CONTENT.value)

Expand Down Expand Up @@ -133,14 +203,28 @@ async def create_session(
user: User = Depends(current_user(permissions={"sessions": ["write"]})),
):
create_session = CreateSession(**(await request.json()))
kernel_id = create_session.kernel.id
kernel_name = create_session.kernel.name
kernel_server = KernelServer(
kernelspec_path=Path(find_kernelspec(kernel_name)).as_posix(),
kernel_cwd=str(Path(create_session.path).parent),
)
kernel_id = str(uuid.uuid4())
kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None}
await kernel_server.start()
if kernel_name is not None:
# launch a new ("internal") kernel
kernel_server = KernelServer(
kernelspec_path=Path(find_kernelspec(kernel_name)).as_posix(),
kernel_cwd=str(Path(create_session.path).parent),
)
kernel_id = str(uuid.uuid4())
kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None}
await kernel_server.start()
elif kernel_id is not None:
# external kernel
kernel_name = kernels[kernel_id]["name"]
kernel_server = KernelServer(
connection_file=kernel_id_to_connection_file[kernel_id],
write_connection_file=False,
)
kernels[kernel_id]["server"] = kernel_server
await kernel_server.start(launch_kernel=False)
else:
return
session_id = str(uuid.uuid4())
session = {
"id": session_id,
Expand All @@ -149,7 +233,7 @@ async def create_session(
"type": create_session.type,
"kernel": {
"id": kernel_id,
"name": create_session.kernel.name,
"name": kernel_name,
"connections": kernel_server.connections,
"last_activity": kernel_server.last_activity["date"],
"execution_state": kernel_server.last_activity["execution_state"],
Expand Down Expand Up @@ -240,6 +324,15 @@ async def kernel_channels(
accepted_websocket = AcceptedWebSocket(websocket, subprotocol)
if kernel_id in kernels:
kernel_server = kernels[kernel_id]["server"]
if kernel_server is None:
# this is an external kernel
# kernel is already launched, just start a kernel server
kernel_server = KernelServer(
connection_file=kernel_id,
write_connection_file=False,
)
await kernel_server.start(launch_kernel=False)
kernels[kernel_id]["server"] = kernel_server
await kernel_server.serve(accepted_websocket, session_id, permissions)


Expand Down
2 changes: 1 addition & 1 deletion plugins/kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "fps_kernels"
description = "An FPS plugin for the kernels API"
keywords = [ "jupyter", "server", "fastapi", "pluggy", "plugins",]
requires-python = ">=3.7"
dependencies = [ "fps >=0.0.8", "fps-auth-base", "fps-frontend", "fps-yjs", "pyzmq", "websockets", "python-dateutil",]
dependencies = [ "fps >=0.0.8", "fps-auth-base", "fps-frontend", "fps-yjs", "pyzmq", "websockets", "python-dateutil", "watchfiles >=0.16.1,<1"]
dynamic = [ "version",]
[[project.authors]]
name = "Jupyter Development Team"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def test_rest_api(start_jupyverse):
),
)
# wait for Y model to be updated
await asyncio.sleep(0.5)
await asyncio.sleep(1)
# retrieve cells
cells = ydoc.get_array("cells").to_json()
assert cells[0]["outputs"] == [
Expand Down

0 comments on commit e3ea8d3

Please sign in to comment.