Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed import error when exceptiongroup isn't available #2231

Merged
merged 7 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ types-contextvars==2.4.7.2
types-PyYAML==6.0.12.10
types-dataclasses==0.6.6
pytest==7.4.0
trio==0.22.1
anyio@git+https://github.com/agronholm/anyio.git
trio==0.21.0
anyio==3.7.1

# Documentation
mkdocs==1.4.3
Expand Down
20 changes: 20 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
import functools
import sys
import typing
from contextlib import contextmanager

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard

has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup
except ImportError:
has_exceptiongroups = False

T = typing.TypeVar("T")
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]

Expand Down Expand Up @@ -66,3 +74,15 @@ async def __aenter__(self) -> SupportsAsyncCloseType:
async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
await self.entered.close()
return None


@contextmanager
def collapse_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups:
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0] # pragma: no cover

raise exc
19 changes: 2 additions & 17 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
import sys
import typing
from contextlib import contextmanager

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import BaseExceptionGroup

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


@contextmanager
def _convert_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
Expand Down Expand Up @@ -201,7 +186,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

with _convert_excgroups():
with collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
Expand Down
9 changes: 2 additions & 7 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

import pytest

from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


def hello_world(environ, start_response):
status = "200 OK"
Expand Down Expand Up @@ -69,12 +67,9 @@ def test_wsgi_exception(test_client_factory):
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(ExceptionGroup) as exc:
with pytest.raises(RuntimeError), collapse_excgroups():
client.get("/")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], RuntimeError)


def test_wsgi_exc_info(test_client_factory):
# Note that we're testing the WSGI app directly here.
Expand Down
Loading