From 2b1249b2f49372e58a825918137843b344bdcff8 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Thu, 16 Mar 2023 09:51:34 -0700 Subject: [PATCH 1/2] Update mypy configuration, fix issues Update to the current mypy configuration and fix various issues it uncovered. This unfortunately includes importing git module attributes from submodules since the way that GitPython dynamically creates its __all__ variable at the top level is incompatible with mypy. --- pyproject.toml | 8 ++++++++ src/mobu/business/base.py | 1 - src/mobu/business/notebookrunner.py | 6 +++--- src/mobu/exceptions.py | 13 +++++++++---- src/mobu/models/user.py | 4 ++-- tests/business/notebookrunner_test.py | 3 ++- tests/timings_test.py | 10 +++++----- 7 files changed, 29 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 702fa5e0..2cd48edb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,9 +101,17 @@ python_files = ["tests/*.py", "tests/*/*.py"] disallow_untyped_defs = true disallow_incomplete_defs = true ignore_missing_imports = true +local_partial_types = true plugins = ["pydantic.mypy"] +no_implicit_reexport = true show_error_codes = true strict_equality = true warn_redundant_casts = true warn_unreachable = true warn_unused_ignores = true + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true +warn_untyped_fields = true diff --git a/src/mobu/business/base.py b/src/mobu/business/base.py index 11346d64..5f31c3d8 100644 --- a/src/mobu/business/base.py +++ b/src/mobu/business/base.py @@ -194,7 +194,6 @@ async def iter_next() -> T: def dump(self) -> BusinessData: return BusinessData( name=type(self).__name__, - config=self.config, failure_count=self.failure_count, success_count=self.success_count, timings=self.timings.dump(), diff --git a/src/mobu/business/notebookrunner.py b/src/mobu/business/notebookrunner.py index f03bdb92..8461a52c 100644 --- a/src/mobu/business/notebookrunner.py +++ b/src/mobu/business/notebookrunner.py @@ -12,7 +12,7 @@ from tempfile import TemporaryDirectory from typing import Any, Dict, List, Optional -import git +from git.repo import Repo from structlog import BoundLogger from ..exceptions import NotebookRepositoryError @@ -37,7 +37,7 @@ def __init__( self.notebook: Optional[Path] = None self.running_code: Optional[str] = None self._repo_dir = TemporaryDirectory() - self._repo: Optional[git.Repo] = None + self._repo: Optional[Repo] = None self._notebook_paths: Optional[List[Path]] = None def annotations(self) -> Dict[str, str]: @@ -58,7 +58,7 @@ def clone_repo(self) -> None: branch = self.config.repo_branch path = self._repo_dir.name with self.timings.start("clone_repo"): - self._repo = git.Repo.clone_from(url, path, branch=branch) + self._repo = Repo.clone_from(url, path, branch=branch) def find_notebooks(self) -> List[Path]: with self.timings.start("find_notebooks"): diff --git a/src/mobu/exceptions.py b/src/mobu/exceptions.py index 326f4950..c18d8a87 100644 --- a/src/mobu/exceptions.py +++ b/src/mobu/exceptions.py @@ -8,6 +8,8 @@ from aiohttp import ClientResponse, ClientResponseError from safir.datetime import format_datetime_for_logging from safir.slack.blockkit import ( + SlackBaseBlock, + SlackBaseField, SlackCodeBlock, SlackException, SlackMessage, @@ -73,13 +75,14 @@ def to_slack(self) -> SlackMessage: """ return SlackMessage(message=str(self), fields=self.common_fields()) - def common_fields(self) -> list[SlackTextField]: + def common_fields(self) -> list[SlackBaseField]: """Return common fields to put in any alert.""" failed_at = format_datetime_for_logging(self.failed_at) - fields = [ + fields: list[SlackBaseField] = [ SlackTextField(heading="Failed at", text=failed_at), - SlackTextField(heading="User", text=self.user), ] + if self.user: + fields.append(SlackTextField(heading="User", text=self.user)) if self.started_at: started_at = format_datetime_for_logging(self.started_at) field = SlackTextField(heading="Started at", text=started_at) @@ -149,7 +152,9 @@ def to_slack(self) -> SlackMessage: if self.status: intro += f" (status: {self.status})" - attachments = [SlackCodeBlock(heading="Code executed", code=self.code)] + attachments: list[SlackBaseBlock] = [ + SlackCodeBlock(heading="Code executed", code=self.code) + ] if self.error: attachment = SlackCodeBlock(heading="Error", code=self.error) attachments.insert(0, attachment) diff --git a/src/mobu/models/user.py b/src/mobu/models/user.py index fc47073d..966bce77 100644 --- a/src/mobu/models/user.py +++ b/src/mobu/models/user.py @@ -1,7 +1,7 @@ """Data models for an authenticated user.""" import time -from typing import List, Optional +from typing import Any, List, Optional from aiohttp import ClientSession from pydantic import BaseModel, Field @@ -94,7 +94,7 @@ async def create( cls, user: User, scopes: List[str], session: ClientSession ) -> "AuthenticatedUser": token_url = f"{config.environment_url}/auth/api/v1/tokens" - data = { + data: dict[str, Any] = { "username": user.username, "name": "Mobu Test User", "token_type": "user", diff --git a/tests/business/notebookrunner_test.py b/tests/business/notebookrunner_test.py index 1a064484..8ff9656c 100644 --- a/tests/business/notebookrunner_test.py +++ b/tests/business/notebookrunner_test.py @@ -9,7 +9,8 @@ import pytest from aioresponses import aioresponses -from git import Actor, Repo +from git.repo import Repo +from git.util import Actor from httpx import AsyncClient from safir.testing.slack import MockSlackWebhook diff --git a/tests/timings_test.py b/tests/timings_test.py index 7f4737f1..94594132 100644 --- a/tests/timings_test.py +++ b/tests/timings_test.py @@ -44,16 +44,16 @@ def test_timings() -> None: StopwatchData( event="something", annotations={}, - start=first_sw.start_time.isoformat(), - stop=first_sw.stop_time.isoformat(), + start=first_sw.start_time, + stop=first_sw.stop_time, elapsed=first_sw.elapsed.total_seconds(), failed=False, ), StopwatchData( event="else", annotations={"foo": "bar"}, - start=second_sw.start_time.isoformat(), - stop=second_sw.stop_time.isoformat(), + start=second_sw.start_time, + stop=second_sw.stop_time, elapsed=second_sw.elapsed.total_seconds(), failed=True, ), @@ -64,7 +64,7 @@ def test_timings() -> None: assert dump[2] == StopwatchData( event="incomplete", annotations={}, - start=sw.start_time.isoformat(), + start=sw.start_time, stop=None, elapsed=None, ) From 1d7b9ce2024d4fb0ee677fe90145d56c6d16fcf5 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Thu, 16 Mar 2023 11:58:36 -0700 Subject: [PATCH 2/2] Update type annotations Follow the PEP-585 recommendations to stop importing most types from the typing module. Start using Self types where appropriate. Stop using Optional in places where SQR-074 recommends | None instead, and use the | syntax instead of Union. --- src/mobu/business/base.py | 5 +++-- src/mobu/business/jupyterloginloop.py | 4 ++-- src/mobu/business/jupyterpythonloop.py | 4 ++-- src/mobu/business/notebookrunner.py | 10 +++++----- src/mobu/business/tapqueryrunner.py | 4 ++-- src/mobu/cachemachine.py | 4 +--- src/mobu/config.py | 7 +++---- src/mobu/dependencies/manager.py | 8 ++++---- src/mobu/exceptions.py | 22 ++++++++-------------- src/mobu/jupyterclient.py | 14 +++----------- src/mobu/models/business.py | 4 ++-- src/mobu/models/flock.py | 16 ++++++++-------- src/mobu/models/jupyter.py | 14 ++++++++------ src/mobu/models/timings.py | 4 ++-- src/mobu/models/user.py | 10 ++++++---- src/mobu/monkey.py | 4 ++-- src/mobu/timings.py | 16 ++++++++-------- src/mobu/util.py | 5 +++-- src/monkeyflocker/cli.py | 5 ++--- src/monkeyflocker/client.py | 8 ++++---- tests/autostart_test.py | 2 +- tests/conftest.py | 3 ++- tests/handlers/flock_test.py | 6 +++--- tests/monkeyflocker_test.py | 7 ++++--- tests/support/jupyter.py | 20 ++++++++++---------- tests/support/util.py | 4 ++-- 26 files changed, 100 insertions(+), 110 deletions(-) diff --git a/src/mobu/business/base.py b/src/mobu/business/base.py index 5f31c3d8..ec350344 100644 --- a/src/mobu/business/base.py +++ b/src/mobu/business/base.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio -from asyncio import Queue, QueueEmpty, TimeoutError +from asyncio import Queue, QueueEmpty +from collections.abc import AsyncIterable, AsyncIterator from datetime import datetime, timezone from enum import Enum -from typing import AsyncIterable, AsyncIterator, TypeVar +from typing import TypeVar from structlog import BoundLogger diff --git a/src/mobu/business/jupyterloginloop.py b/src/mobu/business/jupyterloginloop.py index eba5580f..21e6a66f 100644 --- a/src/mobu/business/jupyterloginloop.py +++ b/src/mobu/business/jupyterloginloop.py @@ -9,7 +9,7 @@ import asyncio from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Dict, Optional +from typing import Optional from aiohttp import ClientError, ClientResponseError from structlog import BoundLogger @@ -75,7 +75,7 @@ def __init__( async def close(self) -> None: await self._client.close() - def annotations(self) -> Dict[str, str]: + def annotations(self) -> dict[str, str]: """Timer annotations to use. Subclasses should override this to add more annotations based on diff --git a/src/mobu/business/jupyterpythonloop.py b/src/mobu/business/jupyterpythonloop.py index 8f10c626..3dec84b9 100644 --- a/src/mobu/business/jupyterpythonloop.py +++ b/src/mobu/business/jupyterpythonloop.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Dict, Optional +from typing import Optional from structlog import BoundLogger @@ -46,7 +46,7 @@ def __init__( super().__init__(logger, business_config, user) self.node: Optional[str] = None - def annotations(self) -> Dict[str, str]: + def annotations(self) -> dict[str, str]: result = super().annotations() if self.node: result["node"] = self.node diff --git a/src/mobu/business/notebookrunner.py b/src/mobu/business/notebookrunner.py index 8461a52c..b294d2fc 100644 --- a/src/mobu/business/notebookrunner.py +++ b/src/mobu/business/notebookrunner.py @@ -10,7 +10,7 @@ import random from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, List, Optional +from typing import Any, Optional from git.repo import Repo from structlog import BoundLogger @@ -38,9 +38,9 @@ def __init__( self.running_code: Optional[str] = None self._repo_dir = TemporaryDirectory() self._repo: Optional[Repo] = None - self._notebook_paths: Optional[List[Path]] = None + self._notebook_paths: Optional[list[Path]] = None - def annotations(self) -> Dict[str, str]: + def annotations(self) -> dict[str, str]: result = super().annotations() if self.notebook: result["notebook"] = self.notebook.name @@ -60,7 +60,7 @@ def clone_repo(self) -> None: with self.timings.start("clone_repo"): self._repo = Repo.clone_from(url, path, branch=branch) - def find_notebooks(self) -> List[Path]: + def find_notebooks(self) -> list[Path]: with self.timings.start("find_notebooks"): notebooks = [ p @@ -79,7 +79,7 @@ def next_notebook(self) -> None: self._notebook_paths = self.find_notebooks() self.notebook = self._notebook_paths.pop() - def read_notebook(self, notebook: Path) -> List[Dict[str, Any]]: + def read_notebook(self, notebook: Path) -> list[dict[str, Any]]: with self.timings.start("read_notebook", {"notebook": notebook.name}): try: notebook_text = notebook.read_text() diff --git a/src/mobu/business/tapqueryrunner.py b/src/mobu/business/tapqueryrunner.py index 20e974f2..a3de6a27 100644 --- a/src/mobu/business/tapqueryrunner.py +++ b/src/mobu/business/tapqueryrunner.py @@ -7,7 +7,7 @@ import random from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional import jinja2 import pyvo @@ -87,7 +87,7 @@ def _generate_random_polygon( poly.append(dec + r * math.cos(theta)) return ", ".join([str(x) for x in poly]) - def _generate_parameters(self) -> Dict[str, Union[int, float, str]]: + def _generate_parameters(self) -> dict[str, int | float | str]: """Generate some random parameters for the query.""" min_ra = self._params.get("min_ra", 55.0) max_ra = self._params.get("max_ra", 70.0) diff --git a/src/mobu/cachemachine.py b/src/mobu/cachemachine.py index b43cdfe9..80332fff 100644 --- a/src/mobu/cachemachine.py +++ b/src/mobu/cachemachine.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import List - from aiohttp import ClientSession from .config import config @@ -69,7 +67,7 @@ async def get_recommended(self) -> JupyterImage: raise CachemachineError(self._username, "No images found") return images[0] - async def _get_images(self) -> List[JupyterImage]: + async def _get_images(self) -> list[JupyterImage]: headers = {"Authorization": f"bearer {self._token}"} async with self._session.get(self._url, headers=headers) as r: if r.status != 200: diff --git a/src/mobu/config.py b/src/mobu/config.py index 7b22b590..9235b71f 100644 --- a/src/mobu/config.py +++ b/src/mobu/config.py @@ -4,7 +4,6 @@ import os from dataclasses import dataclass -from typing import Optional __all__ = ["Configuration", "config"] @@ -13,7 +12,7 @@ class Configuration: """Configuration for mobu.""" - alert_hook: Optional[str] = os.getenv("ALERT_HOOK") + alert_hook: str | None = os.getenv("ALERT_HOOK") """The slack webhook used for alerting exceptions to slack. Set with the ``ALERT_HOOK`` environment variable. @@ -21,7 +20,7 @@ class Configuration: If not set or set to "None", this feature will be disabled. """ - autostart: Optional[str] = os.getenv("AUTOSTART") + autostart: str | None = os.getenv("AUTOSTART") """The path to a YAML file defining what flocks to automatically start. The YAML file should, if given, be a list of flock specifications. All @@ -48,7 +47,7 @@ class Configuration: Set with the ``CACHEMACHINE_IMAGE_POLICY`` environment variable. """ - gafaelfawr_token: Optional[str] = os.getenv("GAFAELFAWR_TOKEN") + gafaelfawr_token: str | None = os.getenv("GAFAELFAWR_TOKEN") """The Gafaelfawr admin token to use to create user tokens. This token is used to make an admin API call to Gafaelfawr to get a token diff --git a/src/mobu/dependencies/manager.py b/src/mobu/dependencies/manager.py index 41e84776..9023b22b 100644 --- a/src/mobu/dependencies/manager.py +++ b/src/mobu/dependencies/manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Dict, List, Optional +from typing import Optional from aiohttp import ClientSession from aiojobs import Scheduler @@ -19,7 +19,7 @@ class MonkeyBusinessManager: """Manages all of the running monkeys.""" def __init__(self) -> None: - self._flocks: Dict[str, Flock] = {} + self._flocks: dict[str, Flock] = {} self._scheduler: Optional[Scheduler] = None self._session: Optional[ClientSession] = None @@ -56,10 +56,10 @@ def get_flock(self, name: str) -> Flock: raise FlockNotFoundException(name) return flock - def list_flocks(self) -> List[str]: + def list_flocks(self) -> list[str]: return sorted(self._flocks.keys()) - def summarize_flocks(self) -> List[FlockSummary]: + def summarize_flocks(self) -> list[FlockSummary]: return [f.summary() for _, f in sorted(self._flocks.items())] async def stop_flock(self, name: str) -> None: diff --git a/src/mobu/exceptions.py b/src/mobu/exceptions.py index c18d8a87..0d76ebf4 100644 --- a/src/mobu/exceptions.py +++ b/src/mobu/exceptions.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime -from typing import Dict, Optional +from typing import Optional, Self from aiohttp import ClientResponse, ClientResponseError from safir.datetime import format_datetime_for_logging @@ -61,7 +61,7 @@ def __init__(self, user: str, msg: str) -> None: super().__init__(msg, user) self.started_at: Optional[datetime] = None self.event: Optional[str] = None - self.annotations: Dict[str, str] = {} + self.annotations: dict[str, str] = {} def to_slack(self) -> SlackMessage: """Format the error as a Slack Block Kit message. @@ -174,26 +174,22 @@ class JupyterResponseError(MobuSlackException): """Web response error from JupyterHub or JupyterLab.""" @classmethod - def from_exception( - cls, user: str, exc: ClientResponseError - ) -> JupyterResponseError: + def from_exception(cls, user: str, exc: ClientResponseError) -> Self: return cls( url=str(exc.request_info.url), user=user, status=exc.status, - reason=exc.message if exc.message else type(exc).__name__, + reason=exc.message or type(exc).__name__, method=exc.request_info.method, ) @classmethod - async def from_response( - cls, user: str, response: ClientResponse - ) -> JupyterResponseError: + async def from_response(cls, user: str, response: ClientResponse) -> Self: return cls( url=str(response.url), user=user, status=response.status, - reason=response.reason, + reason=response.reason or "", method=response.method, body=await response.text(), ) @@ -204,7 +200,7 @@ def __init__( url: str, user: str, status: int, - reason: Optional[str], + reason: str, method: str, body: Optional[str] = None, ) -> None: @@ -237,9 +233,7 @@ class JupyterSpawnError(MobuSlackException): """The Jupyter Lab pod failed to spawn.""" @classmethod - def from_exception( - cls, user: str, log: str, exc: Exception - ) -> JupyterSpawnError: + def from_exception(cls, user: str, log: str, exc: Exception) -> Self: return cls(user, log, f"{type(exc).__name__}: {str(exc)}") def __init__( diff --git a/src/mobu/jupyterclient.py b/src/mobu/jupyterclient.py index 73ef130b..dd368027 100644 --- a/src/mobu/jupyterclient.py +++ b/src/mobu/jupyterclient.py @@ -11,20 +11,12 @@ import random import re import string +from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timezone from functools import wraps from http.cookies import BaseCookie -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - Optional, - TypeVar, - cast, -) +from typing import Any, Optional, TypeVar, cast from uuid import uuid4 from aiohttp import ( @@ -501,7 +493,7 @@ def _remove_ansi_escapes(string: str) -> str: """ return _ANSI_REGEX.sub("", string) - def _build_jupyter_spawn_form(self, image: JupyterImage) -> Dict[str, str]: + def _build_jupyter_spawn_form(self, image: JupyterImage) -> dict[str, str]: """Construct the form to submit to the JupyterHub login page.""" return { "image_list": str(image), diff --git a/src/mobu/models/business.py b/src/mobu/models/business.py index 1a416e68..b8c95494 100644 --- a/src/mobu/models/business.py +++ b/src/mobu/models/business.py @@ -1,6 +1,6 @@ """Models for monkey business.""" -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -187,7 +187,7 @@ class BusinessData(BaseModel): success_count: int = Field(..., title="Number of successes", example=25) - timings: List[StopwatchData] = Field(..., title="Timings of events") + timings: list[StopwatchData] = Field(..., title="Timings of events") image: Optional[JupyterImage] = Field( None, diff --git a/src/mobu/models/flock.py b/src/mobu/models/flock.py index a509057e..ec80c95d 100644 --- a/src/mobu/models/flock.py +++ b/src/mobu/models/flock.py @@ -1,7 +1,7 @@ """Models for a collection of monkeys.""" from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, Field, validator @@ -21,7 +21,7 @@ class FlockConfig(BaseModel): count: int = Field(..., title="How many monkeys to run", example=100) - users: Optional[List[User]] = Field( + users: Optional[list[User]] = Field( None, title="Explicit list of users to run as", description=( @@ -37,7 +37,7 @@ class FlockConfig(BaseModel): description="Specify either this or users but not both", ) - scopes: List[str] = Field( + scopes: list[str] = Field( ..., title="Token scopes", description="Must include all scopes required to run the business", @@ -63,8 +63,8 @@ class FlockConfig(BaseModel): @validator("users") def _valid_users( - cls, v: Optional[List[User]], values: Dict[str, Any] - ) -> Optional[List[User]]: + cls, v: list[User] | None, values: dict[str, Any] + ) -> list[User] | None: if v is None: return v if "count" in values and len(v) != values["count"]: @@ -74,8 +74,8 @@ def _valid_users( @validator("user_spec", always=True) def _valid_user_spec( - cls, v: Optional[UserSpec], values: Dict[str, Any] - ) -> Optional[UserSpec]: + cls, v: UserSpec | None, values: dict[str, Any] + ) -> UserSpec | None: if v is None and ("users" not in values or values["users"] is None): raise ValueError("one of users or user_spec must be provided") if v and "users" in values and values["users"]: @@ -99,7 +99,7 @@ class FlockData(BaseModel): config: FlockConfig = Field(..., title="Configuration for the flock") - monkeys: List[MonkeyData] = Field(..., title="Monkeys of the flock") + monkeys: list[MonkeyData] = Field(..., title="Monkeys of the flock") class FlockSummary(BaseModel): diff --git a/src/mobu/models/jupyter.py b/src/mobu/models/jupyter.py index ae830a6c..ba6b510d 100644 --- a/src/mobu/models/jupyter.py +++ b/src/mobu/models/jupyter.py @@ -1,7 +1,9 @@ """Models for configuring a Jupyter lab.""" +from __future__ import annotations + from enum import Enum -from typing import Dict, Optional +from typing import Optional, Self from pydantic import BaseModel, Field, validator @@ -45,15 +47,15 @@ def __str__(self) -> str: return "|".join([self.reference, self.name, self.digest or ""]) @classmethod - def from_dict(cls, data: Dict[str, str]) -> "JupyterImage": - return JupyterImage( + def from_dict(cls, data: dict[str, str]) -> Self: + return cls( reference=data["image_url"], name=data["name"], digest=data["image_hash"], ) @classmethod - def from_reference(cls, reference: str) -> "JupyterImage": + def from_reference(cls, reference: str) -> Self: return cls( reference=reference, name=reference.rsplit(":", 1)[1], digest="" ) @@ -90,8 +92,8 @@ class JupyterConfig(BaseModel): @validator("image_reference") def _valid_image_reference( - cls, v: Optional[str], values: Dict[str, object] - ) -> Optional[str]: + cls, v: str | None, values: dict[str, object] + ) -> str | None: if values.get("image_class") == JupyterImageClass.BY_REFERENCE: if not v: raise ValueError("image_reference required") diff --git a/src/mobu/models/timings.py b/src/mobu/models/timings.py index 046d29f1..3de9fee1 100644 --- a/src/mobu/models/timings.py +++ b/src/mobu/models/timings.py @@ -1,7 +1,7 @@ """Models for timing data.""" from datetime import datetime -from typing import Dict, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -11,7 +11,7 @@ class StopwatchData(BaseModel): event: str = Field(..., title="Name of the event", example="lab_create") - annotations: Dict[str, str] = Field( + annotations: dict[str, str] = Field( default_factory=dict, title="Event annotations", example={"notebook": "example.ipynb"}, diff --git a/src/mobu/models/user.py b/src/mobu/models/user.py index 966bce77..3c0db9ae 100644 --- a/src/mobu/models/user.py +++ b/src/mobu/models/user.py @@ -1,7 +1,9 @@ """Data models for an authenticated user.""" +from __future__ import annotations + import time -from typing import Any, List, Optional +from typing import Any, Optional, Self from aiohttp import ClientSession from pydantic import BaseModel, Field @@ -77,7 +79,7 @@ class UserSpec(BaseModel): class AuthenticatedUser(User): """Represents an authenticated user with a token.""" - scopes: List[str] = Field( + scopes: list[str] = Field( ..., title="Token scopes", example=["exec:notebook", "read:tap"], @@ -91,8 +93,8 @@ class AuthenticatedUser(User): @classmethod async def create( - cls, user: User, scopes: List[str], session: ClientSession - ) -> "AuthenticatedUser": + cls, user: User, scopes: list[str], session: ClientSession + ) -> Self: token_url = f"{config.environment_url}/auth/api/v1/tokens" data: dict[str, Any] = { "username": user.username, diff --git a/src/mobu/monkey.py b/src/mobu/monkey.py index 5c54419d..ba87cf92 100644 --- a/src/mobu/monkey.py +++ b/src/mobu/monkey.py @@ -5,7 +5,7 @@ import logging import sys from tempfile import NamedTemporaryFile -from typing import Optional, Type +from typing import Optional import structlog from aiohttp import ClientSession @@ -31,7 +31,7 @@ class Monkey: def __init__( self, monkey_config: MonkeyConfig, - business_type: Type[Business], + business_type: type[Business], user: AuthenticatedUser, session: ClientSession, ): diff --git a/src/mobu/timings.py b/src/mobu/timings.py index 0f2bf460..649e5437 100644 --- a/src/mobu/timings.py +++ b/src/mobu/timings.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone from types import TracebackType -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from .exceptions import MobuSlackException from .models.timings import StopwatchData @@ -19,10 +19,10 @@ class Timings: def __init__(self) -> None: self._last: Optional[Stopwatch] = None - self._stopwatches: List[Stopwatch] = [] + self._stopwatches: list[Stopwatch] = [] def start( - self, event: str, annotations: Optional[Dict[str, str]] = None + self, event: str, annotations: Optional[dict[str, str]] = None ) -> Stopwatch: """Start a stopwatch. @@ -42,7 +42,7 @@ def start( self._last = stopwatch return stopwatch - def dump(self) -> List[StopwatchData]: + def dump(self) -> list[StopwatchData]: """Convert the stored timings to a dictionary.""" return [s.dump() for s in self._stopwatches] @@ -68,7 +68,7 @@ class Stopwatch: def __init__( self, event: str, - annotations: Dict[str, str], + annotations: dict[str, str], previous: Optional[Stopwatch] = None, ) -> None: self.event = event @@ -83,9 +83,9 @@ def __enter__(self) -> Stopwatch: def __exit__( self, - exc_type: Optional[type], - exc_val: Optional[Exception], - exc_tb: Optional[TracebackType], + exc_type: type | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, ) -> Literal[False]: self.stop_time = datetime.now(tz=timezone.utc) if exc_val: diff --git a/src/mobu/util.py b/src/mobu/util.py index 3e40b649..6397c912 100644 --- a/src/mobu/util.py +++ b/src/mobu/util.py @@ -4,7 +4,8 @@ import asyncio from asyncio import Task -from typing import Awaitable, Callable, Coroutine, Optional, TypeVar +from collections.abc import Awaitable, Callable, Coroutine +from typing import TypeVar T = TypeVar("T") @@ -24,7 +25,7 @@ async def loop() -> None: return asyncio.ensure_future(loop()) -async def wait_first(*args: Coroutine[None, None, T]) -> Optional[T]: +async def wait_first(*args: Coroutine[None, None, T]) -> T | None: """Return the result of the first awaitable to finish. The other awaitables will be cancelled. The first awaitable determines diff --git a/src/monkeyflocker/cli.py b/src/monkeyflocker/cli.py index dbc84a16..7e5b6aea 100644 --- a/src/monkeyflocker/cli.py +++ b/src/monkeyflocker/cli.py @@ -3,7 +3,6 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, Union import click from safir.asyncio import run_with_asyncio @@ -24,7 +23,7 @@ def main() -> None: @main.command() @click.argument("topic", default=None, required=False, nargs=1) @click.pass_context -def help(ctx: click.Context, topic: Union[None, str]) -> None: +def help(ctx: click.Context, topic: str | None) -> None: """Show help for any command.""" # The help command implementation is taken from # https://www.burgundywall.com/post/having-click-help-subcommand @@ -124,7 +123,7 @@ async def report(base_url: str, token: str, output: Path, name: str) -> None: @click.argument("name") @run_with_asyncio async def stop( - base_url: str, token: str, output: Optional[Path], name: str + base_url: str, token: str, output: Path | None, name: str ) -> None: """Stop a flock.""" async with MonkeyflockerClient(base_url, token) as client: diff --git a/src/monkeyflocker/client.py b/src/monkeyflocker/client.py index b4acb627..b2e216a6 100644 --- a/src/monkeyflocker/client.py +++ b/src/monkeyflocker/client.py @@ -7,7 +7,7 @@ import sys from pathlib import Path from types import TracebackType -from typing import Literal, Optional +from typing import Literal from urllib.parse import urljoin import aiohttp @@ -47,9 +47,9 @@ async def __aenter__(self) -> MonkeyflockerClient: async def __aexit__( self, - exc_type: Optional[type], - exc_val: Optional[Exception], - exc_tb: Optional[TracebackType], + exc_type: type | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, ) -> Literal[False]: await self.aclose() return False diff --git a/tests/autostart_test.py b/tests/autostart_test.py index e5ae2210..b6338176 100644 --- a/tests/autostart_test.py +++ b/tests/autostart_test.py @@ -2,8 +2,8 @@ from __future__ import annotations +from collections.abc import Iterator from pathlib import Path -from typing import Iterator from unittest.mock import ANY import pytest diff --git a/tests/conftest.py b/tests/conftest.py index d972592a..5252f189 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterator +from typing import Any from unittest.mock import patch import pytest diff --git a/tests/handlers/flock_test.py b/tests/handlers/flock_test.py index dfe5ea77..12fedfc8 100644 --- a/tests/handlers/flock_test.py +++ b/tests/handlers/flock_test.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any from unittest.mock import ANY import pytest @@ -28,7 +28,7 @@ async def test_start_stop( } r = await client.put("/mobu/flocks", json=config) assert r.status_code == 201 - expected: Dict[str, Any] = { + expected: dict[str, Any] = { "name": "test", "config": { "name": "test", @@ -152,7 +152,7 @@ async def test_user_list( } r = await client.put("/mobu/flocks", json=config) assert r.status_code == 201 - expected: Dict[str, Any] = { + expected: dict[str, Any] = { "name": "test", "config": config, "monkeys": [ diff --git a/tests/monkeyflocker_test.py b/tests/monkeyflocker_test.py index 7a1b10a2..777b209a 100644 --- a/tests/monkeyflocker_test.py +++ b/tests/monkeyflocker_test.py @@ -10,8 +10,9 @@ import socket import subprocess import time +from collections.abc import Iterator from pathlib import Path -from typing import Any, Dict, Iterator +from typing import Any from unittest.mock import ANY import httpx @@ -22,7 +23,7 @@ from monkeyflocker.cli import main APP_SOURCE = """ -from typing import Awaitable, Callable +from collections.abc import Awaitable, Callable from aioresponses import aioresponses from fastapi import FastAPI, Request, Response @@ -151,7 +152,7 @@ def test_start_report_stop(tmp_path: Path, app_url: str) -> None: print(result.stdout) assert result.exit_code == 0 - expected: Dict[str, Any] = { + expected: dict[str, Any] = { "name": "basic", "config": { "name": "basic", diff --git a/tests/support/jupyter.py b/tests/support/jupyter.py index 01a8ba1f..6d0d2e02 100644 --- a/tests/support/jupyter.py +++ b/tests/support/jupyter.py @@ -12,7 +12,7 @@ from io import StringIO from re import Pattern from traceback import format_exc -from typing import Any, Dict, Optional, Union +from typing import Any, Optional from unittest.mock import ANY, AsyncMock, Mock from uuid import uuid4 @@ -47,7 +47,7 @@ class JupyterState(Enum): LAB_RUNNING = "lab running" -def _url(route: str, regex: bool = False) -> Union[str, Pattern[str]]: +def _url(route: str, regex: bool = False) -> str | Pattern[str]: """Construct a URL for JupyterHub/Proxy.""" if not regex: return f"{config.environment_url}/nb/{route}" @@ -64,13 +64,13 @@ class MockJupyter: """ def __init__(self) -> None: - self.sessions: Dict[str, JupyterLabSession] = {} - self.state: Dict[str, JupyterState] = {} + self.sessions: dict[str, JupyterLabSession] = {} + self.state: dict[str, JupyterState] = {} self.delete_immediate = True self.spawn_timeout = False self.redirect_loop = False - self._delete_at: Dict[str, Optional[datetime]] = {} - self._fail: Dict[str, Dict[JupyterAction, bool]] = {} + self._delete_at: dict[str, datetime | None] = {} + self._fail: dict[str, dict[JupyterAction, bool]] = {} def fail(self, user: str, action: JupyterAction) -> None: """Configure the given action to fail for the given user.""" @@ -264,11 +264,11 @@ def __init__(self, user: str, session_id: str) -> None: super().__init__(spec=ClientWebSocketResponse) self.user = user self.session_id = session_id - self._header: Optional[Dict[str, str]] = None + self._header: Optional[dict[str, str]] = None self._code: Optional[str] = None - self._state: Dict[str, Any] = {} + self._state: dict[str, Any] = {} - async def send_json(self, message: Dict[str, Any]) -> None: + async def send_json(self, message: dict[str, Any]) -> None: assert message == { "header": { "username": self.user, @@ -292,7 +292,7 @@ async def send_json(self, message: Dict[str, Any]) -> None: self._header = message["header"] self._code = message["content"]["code"] - async def receive_json(self) -> Dict[str, Any]: + async def receive_json(self) -> dict[str, Any]: assert self._header if self._code == _GET_NODE: self._code = None diff --git a/tests/support/util.py b/tests/support/util.py index 04c61638..9feef641 100644 --- a/tests/support/util.py +++ b/tests/support/util.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Any, Dict +from typing import Any from httpx import AsyncClient @@ -12,7 +12,7 @@ async def wait_for_business( client: AsyncClient, username: str -) -> Dict[str, Any]: +) -> dict[str, Any]: """Wait for one loop of business to complete and return its data.""" for _ in range(1, 10): await asyncio.sleep(0.5)