Skip to content

Commit

Permalink
Add support for external kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 24, 2023
1 parent 827c588 commit 61badd6
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 34 deletions.
3 changes: 3 additions & 0 deletions plugins/kernels/fps_kernels/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional

from fps.config import PluginModel, get_config # type: ignore
from fps.hooks import register_config # type: ignore


class KernelConfig(PluginModel):
default_kernel: str = "python3"
connection_path: Optional[str] = None


def get_kernel_config():
Expand Down
28 changes: 13 additions & 15 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
from fastapi import WebSocket, WebSocketDisconnect # type: ignore
from starlette.websockets import WebSocketState

from ..kernel_driver.connect import (
cfg_t,
connect_channel,
launch_kernel,
read_connection_file,
)
from ..kernel_driver.connect import cfg_t, connect_channel
from ..kernel_driver.connect import launch_kernel as _launch_kernel
from ..kernel_driver.connect import read_connection_file
from ..kernel_driver.connect import (
write_connection_file as _write_connection_file, # type: ignore
)
Expand Down Expand Up @@ -108,19 +105,20 @@ def allow_messages(self, message_types: Optional[Iterable[str]] = None):
def connections(self) -> int:
return len(self.sessions)

async def start(self) -> None:
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
async def start(self, launch_kernel: bool = True) -> None:
self.last_activity = {
"date": datetime.utcnow().isoformat() + "Z",
"execution_state": "starting",
}
self.kernel_process = await launch_kernel(
self.kernelspec_path,
self.connection_file_path,
self.kernel_cwd,
self.capture_kernel_output,
)
if launch_kernel:
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
self.kernel_process = await _launch_kernel(
self.kernelspec_path,
self.connection_file_path,
self.kernel_cwd,
self.capture_kernel_output,
)
assert self.connection_cfg is not None
identity = uuid.uuid4().hex.encode("ascii")
self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity)
Expand Down
9 changes: 6 additions & 3 deletions plugins/kernels/fps_kernels/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Optional

from pydantic import BaseModel


class KernelName(BaseModel):
name: str
class KernelInfo(BaseModel):
name: Optional[str] = None
id: Optional[str] = None


class CreateSession(BaseModel):
kernel: KernelName
kernel: KernelInfo
name: str
path: str
type: str
Expand Down
96 changes: 82 additions & 14 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import json
import sys
import uuid
from http import HTTPStatus
from pathlib import Path
from typing import Set, Tuple

from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import FileResponse
Expand All @@ -11,6 +13,7 @@
from fps_frontend.config import get_frontend_config # type: ignore
from fps_yjs.routes import YDocWebSocketHandler # type: ignore
from starlette.requests import Request # type: ignore
from watchfiles import Change, awatch

from .config import get_kernel_config
from .kernel_driver.driver import KernelDriver # type: ignore
Expand All @@ -29,9 +32,45 @@
prefix_dir: Path = Path(sys.prefix)


async def process_connection_files(changes: Set[Tuple[Change, str]]):
for c in changes:
change, path = c
if change == Change.deleted:
if path in kernels:
del kernels[path]
elif change == Change.added:
try:
data = json.loads(Path(path).read_text())
except BaseException:
continue
if "kernel_name" not in data or "key" not in data:
continue
# looks like a kernel connection file
kernel_name = data["kernel_name"] or "python3"
kernel_id = path
kernels[kernel_id] = {"name": kernel_name, "server": None, "driver": None}


async def watch_connection_files(path: Path):
# first time scan
changes = {(Change.added, str(p)) for p in path.iterdir()}
await process_connection_files(changes)
# then, on every change
async for changes in awatch(path):
await process_connection_files(changes)


@router.on_event("startup")
async def startup():
kernel_config = get_kernel_config()
if kernel_config.connection_path is not None:
path = Path(kernel_config.connection_path)
asyncio.create_task(watch_connection_files(path))


@router.on_event("shutdown")
async def stop_kernels():
for kernel in kernels.values():
for kernel in [kernel for kernel in kernels.values() if kernel["server"] is not None]:
await kernel["server"].stop()


Expand Down Expand Up @@ -74,13 +113,21 @@ async def get_kernels(
):
results = []
for kernel_id, kernel in kernels.items():
if kernel["server"]:
connections = kernel["server"].connections
last_activity = kernel["server"].last_activity["date"]
execution_state = kernel["server"].last_activity["execution_state"]
else:
connections = 1
last_activity = ""
execution_state = "idle"
results.append(
{
"id": kernel_id,
"name": kernel["name"],
"connections": kernel["server"].connections,
"last_activity": kernel["server"].last_activity["date"],
"execution_state": kernel["server"].last_activity["execution_state"],
"connections": connections,
"last_activity": last_activity,
"execution_state": execution_state,
}
)
return results
Expand Down Expand Up @@ -133,14 +180,26 @@ async def create_session(
user: User = Depends(current_user(permissions={"sessions": ["write"]})),
):
create_session = CreateSession(**(await request.json()))
kernel_id = create_session.kernel.id
kernel_name = create_session.kernel.name
kernel_server = KernelServer(
kernelspec_path=Path(find_kernelspec(kernel_name)).as_posix(),
kernel_cwd=str(Path(create_session.path).parent),
)
kernel_id = str(uuid.uuid4())
kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None}
await kernel_server.start()
if kernel_name is not None:
kernel_server = KernelServer(
kernelspec_path=Path(find_kernelspec(kernel_name)).as_posix(),
kernel_cwd=str(Path(create_session.path).parent),
)
kernel_id = str(uuid.uuid4())
kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None}
await kernel_server.start()
elif kernel_id is not None:
kernel_name = kernels[kernel_id]["name"]
kernel_server = KernelServer(
connection_file=kernel_id,
write_connection_file=False,
)
kernels[kernel_id]["server"] = kernel_server
await kernel_server.start(launch_kernel=False)
else:
return
session_id = str(uuid.uuid4())
session = {
"id": session_id,
Expand All @@ -149,7 +208,7 @@ async def create_session(
"type": create_session.type,
"kernel": {
"id": kernel_id,
"name": create_session.kernel.name,
"name": kernel_name,
"connections": kernel_server.connections,
"last_activity": kernel_server.last_activity["date"],
"execution_state": kernel_server.last_activity["execution_state"],
Expand Down Expand Up @@ -205,7 +264,7 @@ async def execute_cell(
ynotebook.set_cell(execution.cell_idx, cell)


@router.get("/api/kernels/{kernel_id}")
@router.get("/api/kernels/{kernel_id:path}")
async def get_kernel(
kernel_id,
user: User = Depends(current_user(permissions={"kernels": ["read"]})),
Expand All @@ -222,7 +281,7 @@ async def get_kernel(
return result


@router.websocket("/api/kernels/{kernel_id}/channels")
@router.websocket("/api/kernels/{kernel_id:path}/channels")
async def kernel_channels(
kernel_id,
session_id,
Expand All @@ -240,6 +299,15 @@ async def kernel_channels(
accepted_websocket = AcceptedWebSocket(websocket, subprotocol)
if kernel_id in kernels:
kernel_server = kernels[kernel_id]["server"]
if kernel_server is None:
# this is an external kernel
# kernel is already launched, just start a kernel server
kernel_server = KernelServer(
connection_file=kernel_id,
write_connection_file=False,
)
await kernel_server.start(launch_kernel=False)
kernels[kernel_id]["server"] = kernel_server
await kernel_server.serve(accepted_websocket, session_id, permissions)


Expand Down
2 changes: 1 addition & 1 deletion plugins/kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "fps_kernels"
description = "An FPS plugin for the kernels API"
keywords = [ "jupyter", "server", "fastapi", "pluggy", "plugins",]
requires-python = ">=3.7"
dependencies = [ "fps >=0.0.8", "fps-auth-base", "fps-frontend", "fps-yjs", "pyzmq", "websockets", "python-dateutil",]
dependencies = [ "fps >=0.0.8", "fps-auth-base", "fps-frontend", "fps-yjs", "pyzmq", "websockets", "python-dateutil", "watchfiles >=0.16.1,<1"]
dynamic = [ "version",]
[[project.authors]]
name = "Jupyter Development Team"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def test_rest_api(start_jupyverse):
),
)
# wait for Y model to be updated
await asyncio.sleep(0.5)
await asyncio.sleep(1)
# retrieve cells
cells = ydoc.get_array("cells").to_json()
assert cells[0]["outputs"] == [
Expand Down

0 comments on commit 61badd6

Please sign in to comment.