Skip to content

Commit

Permalink
🐛 Fix: 新增 Lifespan._on_ready() 供适配器使用 (#2483)
Browse files Browse the repository at this point in the history
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 10, 2023
1 parent 9152740 commit 8f3f385
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 67 deletions.
12 changes: 0 additions & 12 deletions nonebot/drivers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup

from ._lifespan import LIFESPAN_FUNC, Lifespan

try:
import uvicorn
from fastapi.responses import Response
Expand Down Expand Up @@ -97,8 +95,6 @@ def __init__(self, env: Env, config: NoneBotConfig):

self.fastapi_config: Config = Config(**config.dict())

self._lifespan = Lifespan()

self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
Expand Down Expand Up @@ -155,14 +151,6 @@ async def _handle(websocket: WebSocket) -> None:
name=setup.name,
)

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)

@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
Expand Down
14 changes: 0 additions & 14 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver

from ._lifespan import LIFESPAN_FUNC, Lifespan

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand All @@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self._lifespan = Lifespan()

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False

Expand All @@ -52,16 +48,6 @@ def logger(self):
"""none driver 使用的 logger"""
return logger

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@override
def run(self, *args, **kwargs):
"""启动 none driver"""
Expand Down
27 changes: 3 additions & 24 deletions nonebot/drivers/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,7 @@
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast

from pydantic import BaseSettings

Expand Down Expand Up @@ -57,8 +46,6 @@
"Install with pip: `pip install nonebot2[quart]`"
) from e

_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])


def catch_closed(func):
@wraps(func)
Expand Down Expand Up @@ -102,6 +89,8 @@ def __init__(self, env: Env, config: NoneBotConfig):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)

@property
@override
Expand Down Expand Up @@ -150,16 +139,6 @@ async def _handle() -> None:
view_func=_handle,
)

@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore

@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore

@override
def run(
self,
Expand Down
4 changes: 4 additions & 0 deletions nonebot/internal/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, AsyncGenerator

from nonebot.config import Config
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
from nonebot.internal.driver import (
Driver,
Request,
Expand Down Expand Up @@ -97,6 +98,9 @@ async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
async with self.driver.websocket(setup) as ws:
yield ws

def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self.driver._lifespan.on_ready(func)

@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
Expand All @@ -21,6 +22,10 @@ def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func

def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func

@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
Expand All @@ -35,6 +40,9 @@ async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)

async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
Expand Down
18 changes: 9 additions & 9 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator

from nonebot.log import logger
from nonebot.config import Env, Config
Expand All @@ -16,6 +16,7 @@
T_BotDisconnectionHook,
)

from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, env: Env, config: Config):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -100,15 +102,13 @@ def run(self, *args, **kwargs):

self.on_shutdown(self._cleanup)

@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
Expand Down
32 changes: 24 additions & 8 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pytest
from nonebug import App

from utils import FakeAdapter
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
Expand All @@ -25,34 +25,50 @@


@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
@pytest.mark.parametrize(
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
async def test_lifespan(driver: Driver):
adapter = FakeAdapter(driver)

start_log = []
ready_log = []
shutdown_log = []

@lifespan.on_startup
@driver.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)

@lifespan.on_startup
@driver.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)

@lifespan.on_shutdown
@adapter.on_ready
def _ready1():
assert start_log == [1, 2]
assert ready_log == []
ready_log.append(1)

@adapter.on_ready
def _ready2():
assert ready_log == [1]
ready_log.append(2)

@driver.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)

@lifespan.on_shutdown
@driver.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)

async with lifespan:
async with driver._lifespan:
assert start_log == [1, 2]
assert ready_log == [1, 2]

assert shutdown_log == [1, 2]

Expand Down

0 comments on commit 8f3f385

Please sign in to comment.