From 107cfff2607de55213b201d7e86b72f217658069 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 24 Mar 2023 12:55:36 -0700 Subject: [PATCH 1/6] Do all spawning work in start The original idea to have start return immediately once the lab start has been initiated doesn't work with JupyterHub's assumptions about spawners. Its timeouts and error handling expect all of the work to happen in the start method, and progress must not raise exceptions or JupyterHub reports uncaught exceptions and breaks its UI and API. Follow the design of KubeSpawner and have the start method hold a copy of its task and do all the progress monitoring and event creation. The progress method then just waits for that task to complete, reporting all the events it generates as an iterator. Do a bit of refactoring to move spawn events to a separate class, and add severity to the start of the message similar to what KubeSpawner does. --- src/rsp_restspawner/errors.py | 28 +-- src/rsp_restspawner/spawner.py | 323 ++++++++++++++++++++++++++------- tests/spawner_test.py | 8 +- tests/support/controller.py | 2 +- tox.ini | 2 +- 5 files changed, 271 insertions(+), 92 deletions(-) diff --git a/src/rsp_restspawner/errors.py b/src/rsp_restspawner/errors.py index 8ff7b59..d7bce55 100644 --- a/src/rsp_restspawner/errors.py +++ b/src/rsp_restspawner/errors.py @@ -1,26 +1,18 @@ -from httpx import Response +"""Exceptions for the RSP REST spawner. +JupyterHub catches all exceptions derived from `Exception` and treats them the +same, so the distinction between exceptions is just for better error reporting +and improved code readability. +""" -class SpawnerError(Exception): - def __init__(self, response: Response) -> None: - self._response = response - def __str__(self) -> str: - r = self._response - sc = r.status_code - rp = r.reason_phrase - txt = r.text - url = r.url - return f"Request for {url}: status code {sc} ({rp}): '{txt}'" +class InvalidAuthStateError(Exception): + """The JupyterHub auth state for the user contains no token.""" class MissingFieldError(Exception): - pass - + """The reply from the lab controller is missing a required field.""" -class EventError(Exception): - pass - -class InvalidAuthStateError(Exception): - """The JupyterHub auth state for the user contains no token.""" +class SpawnFailedError(Exception): + """The lab controller reports that the spawn failed.""" diff --git a/src/rsp_restspawner/spawner.py b/src/rsp_restspawner/spawner.py index 241b4b1..7755c50 100644 --- a/src/rsp_restspawner/spawner.py +++ b/src/rsp_restspawner/spawner.py @@ -1,17 +1,21 @@ """Spawner class that uses a REST API to a separate Kubernetes service.""" +from __future__ import annotations + +import asyncio from collections.abc import AsyncIterator +from dataclasses import dataclass from datetime import timedelta from enum import Enum from pathlib import Path -from typing import Optional +from typing import Any, Optional from httpx import AsyncClient from httpx_sse import ServerSentEvent, aconnect_sse from jupyterhub.spawner import Spawner from traitlets import Unicode, default -from .errors import InvalidAuthStateError, MissingFieldError, SpawnerError +from .errors import InvalidAuthStateError, MissingFieldError, SpawnFailedError __all__ = [ "LabStatus", @@ -35,6 +39,52 @@ class LabStatus(str, Enum): FAILED = "failed" +@dataclass(frozen=True, slots=True) +class SpawnEvent: + """JupyterHub spawning event.""" + + progress: int + """Percentage of progress, from 0 to 100.""" + + message: str + """Event description.""" + + complete: bool = False + """Whether the event indicated spawning is done.""" + + failed: bool = False + """Whether the event indicated spawning failed.""" + + @classmethod + def from_sse(cls, sse: ServerSentEvent, progress: int) -> SpawnEvent: + """Convert from a server-sent event from the lab controller. + + Parameters + ---------- + sse + Event from the lab controller. + progress + Current progress percentage. Parsing of the progress events that + communicate this must be done outside of this class. + """ + if sse.event == "complete": + message = "[info] " + (sse.data or "Lab pod successfully spawned") + return cls(progress=90, message=message, complete=True) + elif sse.event == "info": + return cls(progress=progress, message=f"[info] {sse.data}") + elif sse.event == "error": + return cls(progress=progress, message=f"[error] {sse.data}") + elif sse.event == "failed": + message = f"[error] {sse.data}" + return cls(progress=progress, message=message, failed=True) + else: + return cls(progress=progress, message=f"[unknown] {sse.data}") + + def to_dict(self) -> dict[str, int | str]: + """Convert to the dictionary expected by JupyterHub.""" + return {"progress": self.progress, "message": self.message} + + class RSPRestSpawner(Spawner): """Spawner class that sends requests to the RSP lab controller. @@ -87,6 +137,17 @@ class RSPRestSpawner(Spawner): def _env_keep_default(self) -> list[str]: return [] + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # Holds the events from a spawn in progress. + self._events: list[SpawnEvent] = [] + + # Holds the future representing a spawn in progress, used by the + # progress method to know when th spawn is finished and it should + # exit. + self._start_future: Optional[asyncio.Task] = None + @property def _client(self) -> AsyncClient: """Shared `httpx.AsyncClient`.""" @@ -95,32 +156,114 @@ def _client(self) -> AsyncClient: _CLIENT = AsyncClient() return _CLIENT - async def start(self) -> str: - """Returns expected URL of running pod - (returns before creation completes).""" - r = await self._client.post( - self._controller_url("labs", self.user.name, "create"), - headers=await self._user_authorization(), - json={ - "options": self.options_from_form(self.user_options), - "env": self.get_env(), - }, - timeout=self.start_timeout, - follow_redirects=False, - ) - if r.status_code == 409 or r.status_code == 303: - # For the Conflict we need to check the status ourself. - # This route requires an admin token - r = await self._client.get( - self._controller_url("labs", self.user.name), - headers=self._admin_authorization(), + def start(self) -> asyncio.Task: + """Start the user's pod. + + Initiates the pod start operation and then waits for the pod to spawn + by watching the event stream, converting those events into the format + expected by JupyterHub and returned by `progress`. Returns only when + the pod is running and JupyterHub should start waiting for the lab + process to start responding. + + Returns + ------- + asyncio.Task + Running task monitoring the progress of the spawn. This task will + be started before it is returned. When the task is complete, it + will return the cluster-internal URL of the running Jupyter lab + process. + + Notes + ----- + The actual work is done in `_start`. This is a tiny wrapper to do + bookkeeping on the event stream and record the running task so that + `progress` can notice when the task is complete and return. + + It is tempting to only initiate the pod spawn here, return + immediately, and then let JupyterHub follow progress via the + `progress` API. However, this is not what JupyterHub is expecting. + The entire spawn process must happen before the `start` method returns + for the configured timeouts to work properly; once `start` has + returned, JupyterHub only allows a much shorter timeout for the lab to + fully start. + + In addition, JupyterHub handles exceptions from `start` and correctly + recognizes that the pod has failed to start, but exceptions from + `progress` are treated as uncaught exceptions and cause the UI to + break. Therefore, `progress` must never fail and all operations that + may fail need to be done in `start`. + """ + self._start_future = asyncio.create_task(self._start()) + return self._start_future + + async def _start(self) -> str: + """Spawn the user's lab. + + This is the core of the work of `start`. Ask the lab controller to + create the lab and monitor its progress, generating events that are + stored in the ``_events`` attribute for `progress`. + + Returns + ------- + str + Cluster-internal URL of the running Jupyter lab process. + + Notes + ----- + JupyterHub itself arranges for two spawns for the same spawner object + to not be running at the same time, so we ignore that possibility. + """ + self._events = [] + try: + r = await self._client.post( + self._controller_url("labs", self.user.name, "create"), + headers=await self._user_authorization(), + json={ + "options": self.options_from_form(self.user_options), + "env": self.get_env(), + }, + timeout=self.start_timeout, ) - if r.status_code == 200: - obj = r.json() - if "internal_url" in obj: - return obj["internal_url"] - raise MissingFieldError(f"Response '{obj}' missing 'internal_url'") - raise SpawnerError(r) + + # 409 (Conflict) indicates the user already has a running pod. See + # if it really is running, and if so, return its URL. + if r.status_code == 409: + return await self._get_internal_url() + else: + r.raise_for_status() + + # The spawn is now in progress. Monitor the events endpoint until + # we get a completion or failure event. + progress = 0 + timeout = timedelta(seconds=self.start_timeout) + async for sse in self._get_progress_events(timeout): + if sse.event == "progress": + try: + progress = int(sse.data) + except ValueError: + msg = "Invalid progress value: {sse.data}" + self.log.error(msg) + continue + event = SpawnEvent.from_sse(sse, progress) + self._events.append(event) + if event.complete: + break + if event.failed: + raise SpawnFailedError(event.message) + + # Return the internal URL of the spawned pod. + return await self._get_internal_url() + + except Exception: + # We see no end of problems caused by stranded half-created pods, + # so whenever anything goes wrong, try to delete anything we may + # have left behind before raising the fatal exception. + self.log.warning("Spawn failed, attempting to delete any remnants") + try: + await self.stop() + except Exception: + self.log.exception("Failed to delete lab after spawn failure") + raise async def stop(self) -> None: r = await self._client.delete( @@ -128,10 +271,11 @@ async def stop(self) -> None: timeout=300.0, headers=self._admin_authorization(), ) - if r.status_code == 202 or r.status_code == 404: - # We're deleting it, or it wasn't there to start with. + if r.status_code == 404: + # Nothing to delete, treat that as success. return - raise SpawnerError(r) + else: + r.raise_for_status() async def poll(self) -> Optional[int]: """ @@ -152,8 +296,8 @@ async def poll(self) -> Optional[int]: ) if r.status_code == 404: return 0 # No lab for user. - if r.status_code != 200: - raise SpawnerError(r) + else: + r.raise_for_status() result = r.json() if result["status"] == LabStatus.FAILED: return 1 @@ -165,44 +309,74 @@ async def options_form(self, spawner: Spawner) -> str: self._controller_url("lab-form", self.user.name), headers=await self._user_authorization(), ) - if r.status_code != 200: - raise SpawnerError(r) + r.raise_for_status() return r.text - async def progress(self) -> AsyncIterator[dict[str, bool | int | str]]: - progress = 0 - timeout = timedelta(seconds=self.start_timeout) - try: - async for sse in self._get_progress_events(timeout): - if sse.event == "complete": - yield { - "progress": 90, - "message": sse.data or "Lab pod running", - "ready": True, - } - return - elif sse.event == "progress": - try: - progress = int(sse.data) - except ValueError: - msg = "Invalid progress value: {sse.data}" - self.log.error(msg) - continue - elif sse.event in ("info", "error", "failed"): - if not sse.data: - continue - yield { - "progress": progress, - "message": sse.data, - "ready": False, - } - if sse.event == "failed": - return - else: - self.log.error(f"Unknown event type {sse.event}") - except TimeoutError: - msg = f"No update from event stream in {timeout}s, giving up" - self.log.error(msg) + async def progress(self) -> AsyncIterator[dict[str, int | str]]: + """Monitor the progress of a spawn. + + This method is the internal implementation of the progress API. It + provides an iterator of spawn events and then ends when the spawn + succeeds or fails. + + Yields + ------ + dict of str to str or int + Dictionary representing the event with fields ``progress``, + containing an integer completion percentage, and ``message``, + containing a human-readable description of the event. + + Notes + ----- + This method must never raise exceptions, since those will be treated + as unhandled exceptions by JupyterHub. If anything fails, just stop + the iterator. It doesn't do any HTTP calls itself, just monitors the + events created by `start`. + + Uses the internal ``_start_future`` attribute to track when the + related `start` method has completed. + """ + next_event = 0 + complete = False + + # Capture the current future and event stream in a local variable so + # that we consistently monitor the same invocation of start. If that + # one aborts and someone kicks off another one, we want to keep + # following the first one until it completes, not switch streams to + # the second one. + start_future = self._start_future + events = self._events + + # We were apparently called before start was called, so there's + # nothing to report. + if not start_future: + return + + while not complete: + if start_future.done(): + # Indicate that we're done, but continue to execute the rest + # of the loop. We want to process any events received before + # the spawner finishes and report them before ending the + # stream. + complete = True + + # This logic tries to ensure that we don't repeat events even + # though start will be adding more events while we're working. + # A new spawn replaces the events array when it starts, so grab a + # local reference so that we can finish processing it even if it's + # replaced while we work. + len_events = len(events) + for i in range(next_event, len_events): + yield events[i].to_dict() + next_event = len_events + + # This delay waiting for new events is obnoxious, and ideally we + # would do better with some sort of synchronization primitive, but + # there may be multiple invocations of progress watching the same + # invocation of start and this has the merits of simplicity. It's + # also what Kubespawner does. + if not complete: + await asyncio.sleep(1) def _controller_url(self, *components: str) -> str: """Build a URL to the Nublado lab controller. @@ -219,6 +393,19 @@ def _controller_url(self, *components: str) -> str: """ return self.controller_url + "/spawner/v1/" + "/".join(components) + async def _get_internal_url(self) -> str: + """Get the cluster-internal URL of a user's pod.""" + r = await self._client.get( + self._controller_url("labs", self.user.name), + headers=self._admin_authorization(), + ) + r.raise_for_status() + url = r.json().get("internal_url") + if not url: + msg = f"Invalid lab status for {self.user.name}" + raise MissingFieldError(msg) + return url + async def _get_progress_events( self, timeout: timedelta ) -> AsyncIterator[ServerSentEvent]: diff --git a/tests/spawner_test.py b/tests/spawner_test.py index edeffe3..5ad896c 100644 --- a/tests/spawner_test.py +++ b/tests/spawner_test.py @@ -59,13 +59,13 @@ async def test_options_form(spawner: RSPRestSpawner) -> None: @pytest.mark.asyncio async def test_progress(spawner: RSPRestSpawner) -> None: await spawner.start() + user = spawner.user.name expected = [ - {"progress": 2, "message": "Lab creation initiated", "ready": False}, - {"progress": 45, "message": "Pod requested", "ready": False}, + {"progress": 2, "message": "[info] Lab creation initiated"}, + {"progress": 45, "message": "[info] Pod requested"}, { "progress": 90, - "message": f"Pod successfully spawned for {spawner.user.name}", - "ready": True, + "message": f"[info] Pod successfully spawned for {user}", }, ] index = 0 diff --git a/tests/support/controller.py b/tests/support/controller.py index 1c6394a..e8311da 100644 --- a/tests/support/controller.py +++ b/tests/support/controller.py @@ -88,7 +88,7 @@ def create(self, request: Request, user: str) -> Response: return Response(status_code=409) self._lab_status[user] = LabStatus.RUNNING location = f"{self._url}/{user}" - return Response(status_code=303, headers={"Location": location}) + return Response(status_code=201, headers={"Location": location}) def delete(self, request: Request, user: str) -> Response: if self._lab_status.get(user): diff --git a/tox.ini b/tox.ini index 3988a93..970f612 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ deps = -r{toxinidir}/requirements/main.txt -r{toxinidir}/requirements/dev.txt commands = - pytest --cov=rsp_restspawner --cov-branch --cov-report= {posargs} + pytest -vv --cov=rsp_restspawner --cov-branch --cov-report= {posargs} setenv = {[base]setenv} From 9027efa76e4448e090ec2b4cb932cf44cf052bbe Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 24 Mar 2023 08:38:03 -0700 Subject: [PATCH 2/6] Test the spawner sends the right tokens Add tests to the controller mock to ensure that the spawner sends the correct tokens to the different routes. --- tests/conftest.py | 8 +++++++- tests/support/controller.py | 38 ++++++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c371824..71cff3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,13 @@ @pytest.fixture def mock_lab_controller(respx_mock: respx.Router) -> MockLabController: url = "https://rsp.example.org/nublado" - return register_mock_lab_controller(respx_mock, url) + admin_token = (Path(__file__).parent / "data" / "admin-token").read_text() + return register_mock_lab_controller( + respx_mock, + url, + user_token="token-of-affection", + admin_token=admin_token.strip(), + ) @pytest.fixture diff --git a/tests/support/controller.py b/tests/support/controller.py index e8311da..4d1116a 100644 --- a/tests/support/controller.py +++ b/tests/support/controller.py @@ -71,6 +71,10 @@ class MockLabController: ---------- base_url Base URL with which the mock was configured. + user_token + User token expected for routes requiring user authentication. + admin_token + JupyterHub token expected for routes only it can use. Parameters ---------- @@ -78,12 +82,17 @@ class MockLabController: Base URL where the mock is installed, used for constructing redirects. """ - def __init__(self, base_url: str) -> None: + def __init__( + self, base_url: str, user_token: str, admin_token: str + ) -> None: self.base_url = base_url + self._user_token = user_token + self._admin_token = admin_token self._url = f"{base_url}/spawner/v1" self._lab_status: dict[str, LabStatus] = {} def create(self, request: Request, user: str) -> Response: + self._check_authorization(request) if self._lab_status.get(user): return Response(status_code=409) self._lab_status[user] = LabStatus.RUNNING @@ -91,6 +100,7 @@ def create(self, request: Request, user: str) -> Response: return Response(status_code=201, headers={"Location": location}) def delete(self, request: Request, user: str) -> Response: + self._check_authorization(request, admin=True) if self._lab_status.get(user): del self._lab_status[user] return Response(status_code=202) @@ -98,6 +108,7 @@ def delete(self, request: Request, user: str) -> Response: return Response(status_code=404) def events(self, request: Request, user: str) -> Response: + self._check_authorization(request) if not self._lab_status.get(user): return Response(status_code=404) stream = MockProgress(user) @@ -108,6 +119,7 @@ def events(self, request: Request, user: str) -> Response: ) def lab_form(self, request: Request, user: str) -> Response: + self._check_authorization(request) return Response( status_code=200, text=f"

This is some lab form for {user}

" ) @@ -117,6 +129,7 @@ def set_status(self, user: str, status: LabStatus) -> None: self._lab_status[user] = status def status(self, request: Request, user: str) -> Response: + self._check_authorization(request, admin=True) if not self._lab_status.get(user): return Response(status_code=404) return Response( @@ -127,9 +140,24 @@ def status(self, request: Request, user: str) -> Response: }, ) + def _check_authorization( + self, request: Request, admin: bool = False + ) -> None: + authorization = request.headers["Authorization"] + auth_type, token = authorization.split(None, 1) + assert auth_type.lower() == "bearer" + if admin: + assert token == self._admin_token + else: + assert token == self._user_token + def register_mock_lab_controller( - respx_mock: respx.Router, base_url: str + respx_mock: respx.Router, + base_url: str, + *, + user_token: str, + admin_token: str, ) -> MockLabController: """Mock out a Nublado lab controller. @@ -139,6 +167,10 @@ def register_mock_lab_controller( Mock router. base_url Base URL for the lab controller. + user_token + User token expected for routes requiring user authentication. + admin_token + JupyterHub token expected for routes only it can use. Returns ------- @@ -151,7 +183,7 @@ def register_mock_lab_controller( events_url = f"{base_labs_url}/events$" lab_form_url = f"{base_url}/spawner/v1/lab-form/(?P[^/]*)$" - mock = MockLabController(base_url) + mock = MockLabController(base_url, user_token, admin_token) respx_mock.get(url__regex=lab_url).mock(side_effect=mock.status) respx_mock.delete(url__regex=lab_url).mock(side_effect=mock.delete) respx_mock.post(url__regex=create_url).mock(side_effect=mock.create) From 0702be4d75c205adcb7f2c38a0911669e6de8b32 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 24 Mar 2023 14:10:31 -0700 Subject: [PATCH 3/6] Add more tests for spawning and progress Test failures during spawn, delays in returning the progress messages, and multiple event watchers. --- tests/spawner_test.py | 67 +++++++++++++++++++++++++++++++++++++ tests/support/controller.py | 46 ++++++++++++++++++------- 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/tests/spawner_test.py b/tests/spawner_test.py index 5ad896c..c4b57d9 100644 --- a/tests/spawner_test.py +++ b/tests/spawner_test.py @@ -2,13 +2,27 @@ from __future__ import annotations +import asyncio +from datetime import timedelta + import pytest +from rsp_restspawner.errors import SpawnFailedError from rsp_restspawner.spawner import LabStatus, RSPRestSpawner from .support.controller import MockLabController +async def gather_progress( + spawner: RSPRestSpawner, +) -> list[dict[str, int | str]]: + """Gather progress from a spawner and return it as a list when done.""" + result = [] + async for message in spawner.progress(): + result.append(message) + return result + + @pytest.mark.asyncio async def test_start(spawner: RSPRestSpawner) -> None: user = spawner.user.name @@ -73,3 +87,56 @@ async def test_progress(spawner: RSPRestSpawner) -> None: assert message == expected[index] index += 1 assert index == len(expected) + + +@pytest.mark.asyncio +async def test_progress_multiple( + spawner: RSPRestSpawner, mock_lab_controller: MockLabController +) -> None: + """Test multiple progress listeners for the same spawn.""" + mock_lab_controller.delay = timedelta(milliseconds=750) + user = spawner.user.name + expected = [ + {"progress": 2, "message": "[info] Lab creation initiated"}, + {"progress": 45, "message": "[info] Pod requested"}, + { + "progress": 90, + "message": f"[info] Pod successfully spawned for {user}", + }, + ] + + results = await asyncio.gather( + spawner.start(), + gather_progress(spawner), + gather_progress(spawner), + gather_progress(spawner), + ) + url = results.pop(0) + assert url == f"http://lab.nublado-{user}:8888" + for events in results: + assert events == expected + + +@pytest.mark.asyncio +async def test_spawn_failure( + spawner: RSPRestSpawner, mock_lab_controller: MockLabController +) -> None: + """Test error handling when a spawn fails.""" + mock_lab_controller.delay = timedelta(milliseconds=750) + mock_lab_controller.fail_during_spawn = True + user = spawner.user.name + expected = [ + {"progress": 2, "message": "[info] Lab creation initiated"}, + {"progress": 45, "message": "[info] Pod requested"}, + {"progress": 45, "message": "[error] Something is going wrong"}, + { + "progress": 45, + "message": f"[error] Some random failure for {user}", + }, + ] + + results = await asyncio.gather( + spawner.start(), gather_progress(spawner), return_exceptions=True + ) + assert isinstance(results[0], SpawnFailedError) + assert results[1] == expected diff --git a/tests/support/controller.py b/tests/support/controller.py index 4d1116a..b04eb74 100644 --- a/tests/support/controller.py +++ b/tests/support/controller.py @@ -5,7 +5,6 @@ import asyncio from collections.abc import AsyncIterator from datetime import timedelta -from typing import Optional import respx from httpx import AsyncByteStream, Request, Response @@ -30,11 +29,16 @@ class MockProgress(AsyncByteStream): Name of user for which progress events should be generated. delay Delay by this long between events. + fail_during_spawn + Whether to emit a failure message instead of a completion message. """ - def __init__(self, user: str, delay: Optional[timedelta] = None) -> None: + def __init__( + self, user: str, delay: timedelta, fail_during_spawn: bool = False + ) -> None: self._user = user - self._delay = delay if delay else timedelta(seconds=0) + self._delay = delay + self._fail_during_spawn = fail_during_spawn async def __aiter__(self) -> AsyncIterator[bytes]: yield b"event: progress\r\n" @@ -55,10 +59,19 @@ async def __aiter__(self) -> AsyncIterator[bytes]: await asyncio.sleep(self._delay.total_seconds()) - yield b"event: complete\r\n" - msg = f"Pod successfully spawned for {self._user}" - yield b"data: " + msg.encode() + b"\r\n" - yield b"\r\n" + if self._fail_during_spawn: + yield b"event: error\r\n" + yield b"data: Something is going wrong\r\n" + yield b"\r\n" + yield b"event: failed\r\n" + msg = f"Some random failure for {self._user}" + yield b"data: " + msg.encode() + b"\r\n" + yield b"\r\n" + else: + yield b"event: complete\r\n" + msg = f"Pod successfully spawned for {self._user}" + yield b"data: " + msg.encode() + b"\r\n" + yield b"\r\n" class MockLabController: @@ -71,21 +84,25 @@ class MockLabController: ---------- base_url Base URL with which the mock was configured. - user_token - User token expected for routes requiring user authentication. - admin_token - JupyterHub token expected for routes only it can use. + delay + Set this to the desired delay between server-sent events. Parameters ---------- base_url Base URL where the mock is installed, used for constructing redirects. + user_token + User token expected for routes requiring user authentication. + admin_token + JupyterHub token expected for routes only it can use. """ def __init__( self, base_url: str, user_token: str, admin_token: str ) -> None: self.base_url = base_url + self.delay = timedelta(seconds=0) + self.fail_during_spawn = False self._user_token = user_token self._admin_token = admin_token self._url = f"{base_url}/spawner/v1" @@ -95,7 +112,10 @@ def create(self, request: Request, user: str) -> Response: self._check_authorization(request) if self._lab_status.get(user): return Response(status_code=409) - self._lab_status[user] = LabStatus.RUNNING + if self.fail_during_spawn: + self._lab_status[user] = LabStatus.FAILED + else: + self._lab_status[user] = LabStatus.RUNNING location = f"{self._url}/{user}" return Response(status_code=201, headers={"Location": location}) @@ -111,7 +131,7 @@ def events(self, request: Request, user: str) -> Response: self._check_authorization(request) if not self._lab_status.get(user): return Response(status_code=404) - stream = MockProgress(user) + stream = MockProgress(user, self.delay, self.fail_during_spawn) return Response( status_code=200, headers={"Content-Type": "text/event-stream"}, From a37a7e7d17e4439170b35d4f8d36e9686a0efcc1 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 24 Mar 2023 14:51:01 -0700 Subject: [PATCH 4/6] Improve progress messages and debugging Be kind to our future selves and return more details in unusual spawn situations via the event stream. Every bit of debugging information we can see is helpful. Separate the severity from the message for spawn events for clearer construction. Flesh out the docstrings and document which exceptions may be raised. --- src/rsp_restspawner/spawner.py | 140 +++++++++++++++++++++++++++------ tests/spawner_test.py | 13 ++- 2 files changed, 129 insertions(+), 24 deletions(-) diff --git a/src/rsp_restspawner/spawner.py b/src/rsp_restspawner/spawner.py index 7755c50..2e1fd21 100644 --- a/src/rsp_restspawner/spawner.py +++ b/src/rsp_restspawner/spawner.py @@ -49,6 +49,9 @@ class SpawnEvent: message: str """Event description.""" + severity: str + """Log message severity.""" + complete: bool = False """Whether the event indicated spawning is done.""" @@ -68,21 +71,29 @@ def from_sse(cls, sse: ServerSentEvent, progress: int) -> SpawnEvent: communicate this must be done outside of this class. """ if sse.event == "complete": - message = "[info] " + (sse.data or "Lab pod successfully spawned") - return cls(progress=90, message=message, complete=True) + return cls( + progress=90, message=sse.data, severity="info", complete=True + ) elif sse.event == "info": - return cls(progress=progress, message=f"[info] {sse.data}") + return cls(progress=progress, message=sse.data, severity="info") elif sse.event == "error": - return cls(progress=progress, message=f"[error] {sse.data}") + return cls(progress=progress, message=sse.data, severity="error") elif sse.event == "failed": - message = f"[error] {sse.data}" - return cls(progress=progress, message=message, failed=True) + return cls( + progress=progress, + message=sse.data, + severity="error", + failed=True, + ) else: - return cls(progress=progress, message=f"[unknown] {sse.data}") + return cls(progress=progress, message=sse.data, severity="unknown") def to_dict(self) -> dict[str, int | str]: """Convert to the dictionary expected by JupyterHub.""" - return {"progress": self.progress, "message": self.message} + return { + "progress": self.progress, + "message": f"[{self.severity}] {self.message}", + } class RSPRestSpawner(Spawner): @@ -146,7 +157,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Holds the future representing a spawn in progress, used by the # progress method to know when th spawn is finished and it should # exit. - self._start_future: Optional[asyncio.Task] = None + self._start_future: Optional[asyncio.Task[str]] = None @property def _client(self) -> AsyncClient: @@ -156,7 +167,7 @@ def _client(self) -> AsyncClient: _CLIENT = AsyncClient() return _CLIENT - def start(self) -> asyncio.Task: + def start(self) -> asyncio.Task[str]: """Start the user's pod. Initiates the pod start operation and then waits for the pod to spawn @@ -208,12 +219,27 @@ async def _start(self) -> str: str Cluster-internal URL of the running Jupyter lab process. + Raises + ------ + httpx.HTTPError + Raised on failure to talk to the lab controller or a failure + response from the lab controller. + InvalidAuthStateError + Raised if there is no ``token`` attribute in the user's + authentication state. This should always be provided by + `~rsp_restspawner.auth.GafaelfawrAuthenticator`. + MissingFieldError + Raised if the response from the lab controller is invalid. + SpawnFailedError + Raised if the lab controller said that the spawn failed. + Notes ----- JupyterHub itself arranges for two spawns for the same spawner object to not be running at the same time, so we ignore that possibility. """ self._events = [] + progress = 0 try: r = await self._client.post( self._controller_url("labs", self.user.name, "create"), @@ -228,13 +254,16 @@ async def _start(self) -> str: # 409 (Conflict) indicates the user already has a running pod. See # if it really is running, and if so, return its URL. if r.status_code == 409: + event = SpawnEvent( + progress=90, message="Lab already running", severity="info" + ) + self._events.append(event) return await self._get_internal_url() else: r.raise_for_status() # The spawn is now in progress. Monitor the events endpoint until # we get a completion or failure event. - progress = 0 timeout = timedelta(seconds=self.start_timeout) async for sse in self._get_progress_events(timeout): if sse.event == "progress": @@ -259,13 +288,36 @@ async def _start(self) -> str: # so whenever anything goes wrong, try to delete anything we may # have left behind before raising the fatal exception. self.log.warning("Spawn failed, attempting to delete any remnants") + event = SpawnEvent( + progress=progress, + message="Lab creation failed, attempting to clean up", + severity="warning", + ) + self._events.append(event) try: await self.stop() - except Exception: + except Exception as e: self.log.exception("Failed to delete lab after spawn failure") + error = f"{type(e).__name__}: {str(e)}" + event = SpawnEvent( + progress=progress, + message=f"Failed to clean up failed lab: {error}", + severity="error", + ) + self._events.append(event) raise async def stop(self) -> None: + """Delete any running pod for the user. + + If the pod does not exist, treat that as success. + + Raises + ------ + httpx.HTTPError + Raised on failure to talk to the lab controller or a failure + response from the lab controller. + """ r = await self._client.delete( self._controller_url("labs", self.user.name), timeout=300.0, @@ -278,17 +330,24 @@ async def stop(self) -> None: r.raise_for_status() async def poll(self) -> Optional[int]: - """ - Check if the pod is running. + """Check if the pod is running. - If it is, return None. If it has exited, return the return code - if we know it, or 0 if it exited but we don't know how. + Returns + ------- + int or None + If the pod is starting, running, or terminating, return `None`. + If the pod does not exist, return 0. If the pod exists in a failed + state, return 1. - Because we do not have direct access to the pod's exit code, we - are here going to return 0 for "The pod does not exist from the - perspective of the lab controller" (which assumes a good or unknown - exit status) and 1 for "We tried to start a pod, but it failed," which - implies a failure (i.e. non-zero) exit status. + Notes + ----- + In theory, this is supposed to be the exit status of the Jupyter lab + process. This isn't something we know in the classic sense since the + lab is a Kubernetes pod. We only know that something failed if the + record of the lab is hanging around in a failed state, so use a simple + non-zero exit status for that. Otherwise, we have no way to + distinguish between a pod that was shut down without error and a pod + that was stopped, so use an exit status of 0 in both cases. """ r = await self._client.get( self._controller_url("labs", self.user.name), @@ -305,6 +364,24 @@ async def poll(self) -> Optional[int]: return None async def options_form(self, spawner: Spawner) -> str: + """Retrieve the options form for this user from the lab controller. + + Parameters + ---------- + spawner + Another copy of the spawner (not used). It's not clear why + JupyterHub passes this into this method. + + Raises + ------ + httpx.HTTPError + Raised on failure to talk to the lab controller or a failure + response from the lab controller. + InvalidAuthStateError + Raised if there is no ``token`` attribute in the user's + authentication state. This should always be provided by + `~rsp_restspawner.auth.GafaelfawrAuthenticator`. + """ r = await self._client.get( self._controller_url("lab-form", self.user.name), headers=await self._user_authorization(), @@ -394,7 +471,16 @@ def _controller_url(self, *components: str) -> str: return self.controller_url + "/spawner/v1/" + "/".join(components) async def _get_internal_url(self) -> str: - """Get the cluster-internal URL of a user's pod.""" + """Get the cluster-internal URL of a user's pod. + + Raises + ------ + httpx.HTTPError + Raised on failure to talk to the lab controller or a failure + response from the lab controller. + MissingFieldError + Raised if the response from the lab controller is invalid. + """ r = await self._client.get( self._controller_url("labs", self.user.name), headers=self._admin_authorization(), @@ -420,6 +506,16 @@ async def _get_progress_events( ------ ServerSentEvent Next event from the lab controller's event stream. + + Raises + ------ + httpx.HTTPError + Raised on failure to talk to the lab controller or a failure + response from the lab controller. + InvalidAuthStateError + Raised if there is no ``token`` attribute in the user's + authentication state. This should always be provided by + `~rsp_restspawner.auth.GafaelfawrAuthenticator`. """ url = self._controller_url("labs", self.user.name, "events") kwargs = { diff --git a/tests/spawner_test.py b/tests/spawner_test.py index c4b57d9..4dd12b1 100644 --- a/tests/spawner_test.py +++ b/tests/spawner_test.py @@ -111,9 +111,9 @@ async def test_progress_multiple( gather_progress(spawner), gather_progress(spawner), ) - url = results.pop(0) + url = results[0] assert url == f"http://lab.nublado-{user}:8888" - for events in results: + for events in results[1:]: assert events == expected @@ -133,6 +133,10 @@ async def test_spawn_failure( "progress": 45, "message": f"[error] Some random failure for {user}", }, + { + "progress": 45, + "message": "[warning] Lab creation failed, attempting to clean up", + }, ] results = await asyncio.gather( @@ -140,3 +144,8 @@ async def test_spawn_failure( ) assert isinstance(results[0], SpawnFailedError) assert results[1] == expected + + # Because the spawn failed, we should have tried to shut down the lab. It + # therefore should not exist, rather than hanging around in a failed + # state. + assert await spawner.poll() == 0 From 0e4b8bb3849c3d6368888510c42f4e6161c1ca7b Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Wed, 29 Mar 2023 12:14:32 -0700 Subject: [PATCH 5/6] Improve some comments in progress Some of the comments in the progress method were out of date or incomplete. Flesh them out. --- src/rsp_restspawner/spawner.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/rsp_restspawner/spawner.py b/src/rsp_restspawner/spawner.py index 2e1fd21..80bc75b 100644 --- a/src/rsp_restspawner/spawner.py +++ b/src/rsp_restspawner/spawner.py @@ -439,19 +439,17 @@ async def progress(self) -> AsyncIterator[dict[str, int | str]]: # This logic tries to ensure that we don't repeat events even # though start will be adding more events while we're working. - # A new spawn replaces the events array when it starts, so grab a - # local reference so that we can finish processing it even if it's - # replaced while we work. len_events = len(events) for i in range(next_event, len_events): yield events[i].to_dict() next_event = len_events # This delay waiting for new events is obnoxious, and ideally we - # would do better with some sort of synchronization primitive, but - # there may be multiple invocations of progress watching the same - # invocation of start and this has the merits of simplicity. It's - # also what Kubespawner does. + # would do better with some sort of synchronization primitive. + # Using an asyncio.Event per progress invocation would work if + # JupyterHub is always asyncio, but I wasn't sure if it used + # thread pools and asyncio synchronization primitives are not + # thread-safe. The delay approach is what KubeSpawner does. if not complete: await asyncio.sleep(1) From 29c6bc17206ece76feece6e0900c65c2890ae9f1 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Wed, 29 Mar 2023 12:56:55 -0700 Subject: [PATCH 6/6] Rename gather_progress to collect_progress Avoid confusion with asyncio.gather. --- tests/spawner_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/spawner_test.py b/tests/spawner_test.py index 4dd12b1..399d8a4 100644 --- a/tests/spawner_test.py +++ b/tests/spawner_test.py @@ -13,7 +13,7 @@ from .support.controller import MockLabController -async def gather_progress( +async def collect_progress( spawner: RSPRestSpawner, ) -> list[dict[str, int | str]]: """Gather progress from a spawner and return it as a list when done.""" @@ -107,9 +107,9 @@ async def test_progress_multiple( results = await asyncio.gather( spawner.start(), - gather_progress(spawner), - gather_progress(spawner), - gather_progress(spawner), + collect_progress(spawner), + collect_progress(spawner), + collect_progress(spawner), ) url = results[0] assert url == f"http://lab.nublado-{user}:8888" @@ -140,7 +140,7 @@ async def test_spawn_failure( ] results = await asyncio.gather( - spawner.start(), gather_progress(spawner), return_exceptions=True + spawner.start(), collect_progress(spawner), return_exceptions=True ) assert isinstance(results[0], SpawnFailedError) assert results[1] == expected