diff --git a/sentry_asgi/middleware.py b/sentry_asgi/middleware.py index 37d1117..ccf567e 100644 --- a/sentry_asgi/middleware.py +++ b/sentry_asgi/middleware.py @@ -12,6 +12,7 @@ def __init__(self, app): async def __call__(self, scope, receive, send): hub = sentry_sdk.Hub.current with sentry_sdk.Hub(hub) as hub: + scope["sentry_hub"] = hub with hub.configure_scope() as sentry_scope: processor = functools.partial(self.event_processor, asgi_scope=scope) sentry_scope.add_event_processor(processor) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 69954ad..6c926eb 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,8 +3,9 @@ import sys import pytest -from sentry_sdk import capture_message +from sentry_sdk import Hub, capture_message from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse from starlette.testclient import TestClient @@ -30,6 +31,24 @@ async def hi(request): return app +@pytest.fixture +def app_with_extra_middleware(): + class _ExtraMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + sentry_hub: Hub = request.get("sentry_hub") + if sentry_hub: + with sentry_hub.configure_scope() as scope: + scope.user = {"id": "expected_user_id"} + return await call_next(request) + + app = Starlette() + + app.add_middleware(_ExtraMiddleware) + app.add_middleware(SentryMiddleware) + + return app + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") def test_sync_request_data(sentry_init, app, capture_events): sentry_init() @@ -118,3 +137,21 @@ def myerror(request): frame["filename"].endswith("test_middleware.py") for frame in exception["stacktrace"]["frames"] ) + + +def test_sentry_hub_is_set_in_context( + sentry_init, app_with_extra_middleware, capture_events +): + sentry_init() + events = capture_events() + + @app_with_extra_middleware.route("/error") + def myerror(request): + raise ValueError("oh no") + + client = TestClient(app_with_extra_middleware, raise_server_exceptions=False) + client.get("/error") + + event, = events + + assert event["user"]["id"] == "expected_user_id"