Skip to content

Commit

Permalink
✨ Feature: 支持 HTTP 客户端会话 (#2627)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Apr 5, 2024
1 parent 53e2a86 commit 485aa62
Show file tree
Hide file tree
Showing 7 changed files with 417 additions and 62 deletions.
1 change: 1 addition & 0 deletions nonebot/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nonebot.internal.driver import combine_driver as combine_driver
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
from nonebot.internal.driver import HTTPClientSession as HTTPClientSession
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup

Expand Down
158 changes: 123 additions & 35 deletions nonebot/drivers/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@

from typing_extensions import override
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator

from multidict import CIMultiDict

from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers import URL, Request, Response
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
WebSocketClientMixin,
combine_driver,
)
Expand All @@ -39,52 +43,117 @@
) from e


class Mixin(HTTPClientMixin, WebSocketClientMixin):
"""AIOHTTP Mixin"""
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None

self._params = URL.build(query=params).query if params is not None else None

self._headers = CIMultiDict(headers) if headers is not None else None
self._cookies = tuple(
(cookie.name, cookie.value)
for cookie in Cookies(cookies)
if cookie.value is not None
)

version = HTTPVersion(version)
if version == HTTPVersion.H10:
self._version = aiohttp.HttpVersion10
elif version == HTTPVersion.H11:
self._version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")

self._timeout = timeout
self._proxy = proxy

@property
@override
def type(self) -> str:
return "aiohttp"
def client(self) -> aiohttp.ClientSession:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client

@override
async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
if self._params:
params = self._params.copy()
params.update(setup.url.query)
url = setup.url.with_query(params)
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

timeout = aiohttp.ClientTimeout(setup.timeout)
url = setup.url

data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])

cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request(
setup.method,
setup.url,
data=setup.content or data,
json=setup.json,
headers=setup.headers,
timeout=timeout,
proxy=setup.proxy,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
cookies = (
(cookie.name, cookie.value)
for cookie in setup.cookies
if cookie.value is not None
)

timeout = aiohttp.ClientTimeout(setup.timeout)

async with await self.client.request(
setup.method,
url,
data=setup.content or data,
json=setup.json,
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)

@override
async def setup(self) -> None:
self._client = aiohttp.ClientSession(
cookies=self._cookies,
headers=self._headers,
version=self._version,
timeout=self._timeout,
trust_env=True,
)
await self._client.__aenter__()

@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.close()
finally:
self._client = None


class Mixin(HTTPClientMixin, WebSocketClientMixin):
"""AIOHTTP Mixin"""

@property
@override
def type(self) -> str:
return "aiohttp"

@override
async def request(self, setup: Request) -> Response:
async with self.get_session() as session:
return await session.request(setup)

@override
@asynccontextmanager
Expand All @@ -106,6 +175,25 @@ async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
) as ws:
yield WebSocket(request=setup, session=session, websocket=ws)

@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)


class WebSocket(BaseWebSocket):
"""AIOHTTP Websocket Wrapper"""
Expand Down
123 changes: 100 additions & 23 deletions nonebot/drivers/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,20 @@
description: nonebot.drivers.httpx 模块
"""

from typing import TYPE_CHECKING
from typing_extensions import override
from typing import TYPE_CHECKING, Union, Optional

from multidict import CIMultiDict

from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
URL,
Request,
Response,
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
combine_driver,
)

Expand All @@ -36,6 +41,77 @@
) from e


class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[httpx.AsyncClient] = None

self._params = (
tuple(URL.build(query=params).query.items()) if params is not None else None
)
self._headers = (
tuple(CIMultiDict(headers).items()) if headers is not None else None
)
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
self._timeout = timeout
self._proxy = proxy

@property
def client(self) -> httpx.AsyncClient:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client

@override
async def request(self, setup: Request) -> Response:
response = await self.client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
files=setup.files,
json=setup.json,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)

@override
async def setup(self) -> None:
self._client = httpx.AsyncClient(
params=self._params,
headers=self._headers,
cookies=self._cookies.jar,
http2=self._version == HTTPVersion.H2,
proxies=self._proxy,
follow_redirects=True,
)
await self._client.__aenter__()

@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.aclose()
finally:
self._client = None


class Mixin(HTTPClientMixin):
"""HTTPX Mixin"""

Expand All @@ -46,28 +122,29 @@ def type(self) -> str:

@override
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
cookies=setup.cookies.jar,
http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy,
follow_redirects=True,
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
json=setup.json,
files=setup.files,
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)
async with self.get_session(
version=setup.version, proxy=setup.proxy
) as session:
return await session.request(setup)

@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)


if TYPE_CHECKING:
Expand Down
1 change: 1 addition & 0 deletions nonebot/internal/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
from .combine import combine_driver as combine_driver
from .model import HTTPServerSetup as HTTPServerSetup
from .abstract import HTTPClientMixin as HTTPClientMixin
from .abstract import HTTPClientSession as HTTPClientSession
from .model import WebSocketServerSetup as WebSocketServerSetup
from .abstract import WebSocketClientMixin as WebSocketClientMixin
Loading

0 comments on commit 485aa62

Please sign in to comment.