diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..a412cb7002 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,18 @@ +Release type: minor + +This release adds support for multipart subscriptions in almost all[^1] of our +http integrations! + +[Multipart subcriptions](https://www.apollographql.com/docs/router/executing-operations/subscription-multipart-protocol/) +are a new protocol from Apollo GraphQL, built on the +[Incremental Delivery over HTTP spec](https://github.com/graphql/graphql-over-http/blob/main/rfcs/IncrementalDelivery.md), +which is also used for `@defer` and `@stream`. + +The main advantage of this protocol is that when using the Apollo Client +libraries you don't need to install any additional dependency, but in future +this feature should make it easier for us to implement `@defer` and `@stream` + +Also, this means that you don't need to use Django Channels for subscription, +since this protocol is based on HTTP we don't need to use websockets. + +[^1]: Flask, Chalice and the sync Django integration don't support this. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..17b5d5a1c2 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,5 @@ +🆕 Release $version is out! Thanks to $contributor for the PR 👏 + +Strawberry GraphQL now supports @apollographql's multipart subscriptions! 🎉 + +Get it here 👉 $release_url diff --git a/docs/README.md b/docs/README.md index 46b1c7b3b8..7114bceb13 100644 --- a/docs/README.md +++ b/docs/README.md @@ -12,6 +12,7 @@ title: Strawberry docs - [Queries](./general/queries.md) - [Mutations](./general/mutations.md) - [Subscriptions](./general/subscriptions.md) +- [Multipart Subscriptions](./general/multipart-subscriptions.md) - [Why](./general/why.md) - [Breaking changes](./breaking-changes.md) - [Upgrading Strawberry](./general/upgrades.md) diff --git a/docs/general/multipart-subscriptions.md b/docs/general/multipart-subscriptions.md new file mode 100644 index 0000000000..bbaeba768c --- /dev/null +++ b/docs/general/multipart-subscriptions.md @@ -0,0 +1,27 @@ +--- +title: Multipart subscriptions +--- + +# Multipart subscriptions + +Strawberry supports subscription over multipart responses. This is an +[alternative protocol](https://www.apollographql.com/docs/router/executing-operations/subscription-multipart-protocol/) +created by [Apollo](https://www.apollographql.com/) to support subscriptions +over HTTP, and it is supported by default by Apollo Client. + +# Support + +We support multipart subscriptions out of the box in the following HTTP +libraries: + +- Django (only in the Async view) +- ASGI +- Litestar +- FastAPI +- AioHTTP +- Quart + +# Usage + +Multipart subscriptions are automatically enabled when using Subscription, so no +additional configuration is required. diff --git a/docs/integrations/index.md b/docs/integrations/index.md new file mode 100644 index 0000000000..698a909c5b --- /dev/null +++ b/docs/integrations/index.md @@ -0,0 +1,12 @@ +# integrations + +WIP: + +| name | Supports sync | Supports async | Supports subscriptions via websockets | Supports subscriptions via multipart HTTP | Supports file uploads | Supports batch queries | +| --------------------------- | ------------- | -------------------- | ------------------------------------- | ----------------------------------------- | --------------------- | ---------------------- | +| [django](//django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ (From Django 4.2) | ✅ | ❌ | +| [starlette](//starlette.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [aiohttp](//aiohttp.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [flask](//flask.md) | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | +| [channels](//channels.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | +| [fastapi](//fastapi.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/pyproject.toml b/pyproject.toml index 5f5f5f271e..784ee352ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,12 @@ filterwarnings = [ # ignoring the text instead of the whole warning because we'd # get an error when django is not installed "ignore:The default value of USE_TZ", + "ignore::DeprecationWarning:pydantic_openapi_schema.*", + "ignore::DeprecationWarning:graphql.*", + "ignore::DeprecationWarning:websockets.*", + "ignore::DeprecationWarning:pydantic.*", + "ignore::UserWarning:pydantic.*", + "ignore::DeprecationWarning:pkg_resources.*", ] [tool.autopub] diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 88001c6f3a..56a755b2c9 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -7,10 +7,13 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, + Callable, Dict, Iterable, Mapping, Optional, + Union, cast, ) @@ -73,11 +76,17 @@ async def get_form_data(self) -> FormData: @property def content_type(self) -> Optional[str]: - return self.request.content_type + return self.headers.get("content-type") class GraphQLView( - AsyncBaseHTTPView[web.Request, web.Response, web.Response, Context, RootValue] + AsyncBaseHTTPView[ + web.Request, + Union[web.Response, web.StreamResponse], + web.Response, + Context, + RootValue, + ] ): # Mark the view as coroutine so that AIOHTTP does not confuse it with a deprecated # bare handler function. @@ -180,5 +189,29 @@ def create_response( return sub_response + async def create_multipart_response( + self, + request: web.Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: web.Response, + ) -> web.StreamResponse: + response = web.StreamResponse( + status=sub_response.status, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + + await response.prepare(request) + + async for data in stream(): + await response.write(data.encode()) + + await response.write_eof() + + return response + __all__ = ["GraphQLView"] diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index ea05911a3a..26ead659e6 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -5,6 +5,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Callable, Mapping, Optional, Sequence, @@ -14,7 +16,12 @@ from starlette import status from starlette.requests import Request -from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.responses import ( + HTMLResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.websockets import WebSocket from strawberry.asgi.handlers import ( @@ -213,3 +220,19 @@ def create_response( response.status_code = sub_response.status_code return response + + async def create_multipart_response( + self, + request: Request | WebSocket, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, + ) -> Response: + return StreamingResponse( + stream(), + status_code=sub_response.status_code, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 2788b5d921..e7a96d1d7b 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -1,8 +1,3 @@ -"""GraphQLHTTPHandler. - -A consumer to provide a graphql endpoint, and optionally graphiql. -""" - from __future__ import annotations import dataclasses @@ -10,7 +5,17 @@ import warnings from functools import cached_property from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Mapping, + Optional, + Union, +) +from typing_extensions import assert_never from urllib.parse import parse_qs from django.conf import settings @@ -44,6 +49,14 @@ class ChannelsResponse: headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) +@dataclasses.dataclass +class MultipartChannelsResponse: + stream: Callable[[], AsyncGenerator[str, None]] + status: int = 200 + content_type: str = "multipart/mixed;boundary=graphql;subscriptionSpec=1.0" + headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) + + @dataclasses.dataclass class ChannelsRequest: consumer: ChannelsConsumer @@ -186,16 +199,28 @@ def create_response( async def handle(self, body: bytes) -> None: request = ChannelsRequest(consumer=self, body=body) try: - response: ChannelsResponse = await self.run(request) + response = await self.run(request) if b"Content-Type" not in response.headers: response.headers[b"Content-Type"] = response.content_type.encode() - await self.send_response( - response.status, - response.content, - headers=response.headers, - ) + if isinstance(response, MultipartChannelsResponse): + response.headers[b"Transfer-Encoding"] = b"chunked" + await self.send_headers(headers=response.headers) + + async for chunk in response.stream(): + await self.send_body(chunk.encode("utf-8"), more_body=True) + + await self.send_body(b"", more_body=False) + + elif isinstance(response, ChannelsResponse): + await self.send_response( + response.status, + response.content, + headers=response.headers, + ) + else: + assert_never(response) except HTTPException as e: await self.send_response(e.status_code, e.reason.encode()) @@ -204,7 +229,7 @@ class GraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, AsyncBaseHTTPView[ ChannelsRequest, - ChannelsResponse, + Union[ChannelsResponse, MultipartChannelsResponse], TemporalResponse, Context, RootValue, @@ -248,6 +273,16 @@ async def get_context( async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: return TemporalResponse() + async def create_multipart_response( + self, + request: ChannelsRequest, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, + ) -> MultipartChannelsResponse: + status = sub_response.status_code or 200 + headers = {k.encode(): v.encode() for k, v in sub_response.headers.items()} + return MultipartChannelsResponse(stream=stream, status=status, headers=headers) + async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: return ChannelsResponse( content=self.graphql_ide_html.encode(), content_type="text/html" @@ -302,7 +337,7 @@ def run( request: ChannelsRequest, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> ChannelsResponse: + ) -> ChannelsResponse | MultipartChannelsResponse: return super().run(request, context, root_value) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 54a16d8fa1..ee831eabcb 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -5,6 +5,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Mapping, Optional, @@ -14,8 +15,14 @@ from asgiref.sync import markcoroutinefunction from django.core.serializers.json import DjangoJSONEncoder -from django.http import HttpRequest, HttpResponseNotAllowed, JsonResponse -from django.http.response import HttpResponse +from django.http import ( + HttpRequest, + HttpResponse, + HttpResponseNotAllowed, + JsonResponse, + StreamingHttpResponse, +) +from django.http.response import HttpResponseBase from django.template import RequestContext, Template from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string @@ -116,7 +123,7 @@ def headers(self) -> Mapping[str, str]: @property def content_type(self) -> Optional[str]: - return self.request.content_type + return self.headers.get("Content-type") async def get_body(self) -> str: return self.request.body.decode() @@ -159,8 +166,9 @@ def __init__( def create_response( self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse - ) -> HttpResponse: + ) -> HttpResponseBase: data = self.encode_json(response_data) + response = HttpResponse( data, content_type="application/json", @@ -177,6 +185,22 @@ def create_response( return response + async def create_multipart_response( + self, + request: HttpRequest, + stream: Callable[[], AsyncIterator[Any]], + sub_response: TemporalHttpResponse, + ) -> HttpResponseBase: + return StreamingHttpResponse( + streaming_content=stream(), + status=sub_response.status_code, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + def encode_json(self, response_data: GraphQLHTTPResponse) -> str: return json.dumps(response_data, cls=DjangoJSONEncoder) @@ -184,7 +208,7 @@ def encode_json(self, response_data: GraphQLHTTPResponse) -> str: class GraphQLView( BaseView, SyncBaseHTTPView[ - HttpRequest, HttpResponse, TemporalHttpResponse, Context, RootValue + HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue ], View, ): @@ -207,7 +231,7 @@ def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: @method_decorator(csrf_exempt) def dispatch( self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]: + ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: try: return self.run(request=request) except HTTPException as e: @@ -233,7 +257,7 @@ def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: class AsyncGraphQLView( BaseView, AsyncBaseHTTPView[ - HttpRequest, HttpResponse, TemporalHttpResponse, Context, RootValue + HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue ], View, ): @@ -266,7 +290,7 @@ async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: @method_decorator(csrf_exempt) async def dispatch( # pyright: ignore self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]: + ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: try: return await self.run(request=request) except HTTPException as e: diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index a403c33a36..e25dfcd820 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -6,6 +6,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Awaitable, Callable, Dict, @@ -25,6 +26,7 @@ JSONResponse, PlainTextResponse, Response, + StreamingResponse, ) from starlette.websockets import WebSocket @@ -330,5 +332,21 @@ def create_response( return response + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, + ) -> Response: + return StreamingResponse( + stream(), + status_code=sub_response.status_code, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + __all__ = ["GraphQLRouter"] diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index cb773aeb2c..b855c602ea 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -1,7 +1,14 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + Optional, + Union, + cast, +) from flask import Request, Response, render_template_string, request from flask.views import View diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 1e4b5fe7e9..e210861a55 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -1,12 +1,17 @@ import abc +import asyncio +import contextlib import json from typing import ( + Any, + AsyncGenerator, Callable, Dict, Generic, List, Mapping, Optional, + Tuple, Union, ) @@ -15,15 +20,20 @@ from strawberry import UNSET from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files -from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData, process_result +from strawberry.http import ( + GraphQLHTTPResponse, + GraphQLRequestData, + process_result, +) from strawberry.http.ides import GraphQL_IDE from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError -from strawberry.types import ExecutionResult +from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.graphql import OperationType from .base import BaseView from .exceptions import HTTPException +from .parse_content_type import parse_content_type from .types import FormData, HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse @@ -82,9 +92,17 @@ def create_response( @abc.abstractmethod async def render_graphql_ide(self, request: Request) -> Response: ... + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: SubResponse, + ) -> Response: + raise ValueError("Multipart responses are not supported") + async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] - ) -> ExecutionResult: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: request_adapter = self.request_adapter_class(request) try: @@ -178,6 +196,11 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e + if isinstance(result, SubscriptionExecutionResult): + stream = self._get_stream(request, result) + + return await self.create_multipart_response(request, stream, sub_response) + response_data = await self.process_result(request=request, result=result) if result.errors: @@ -187,17 +210,107 @@ async def run( response_data=response_data, sub_response=sub_response ) + def encode_multipart_data(self, data: Any, separator: str) -> str: + return "".join( + [ + f"\r\n--{separator}\r\n", + "Content-Type: application/json\r\n\r\n", + self.encode_json(data), + "\n", + ] + ) + + def _stream_with_heartbeat( + self, stream: Callable[[], AsyncGenerator[str, None]] + ) -> Callable[[], AsyncGenerator[str, None]]: + """Adds a heartbeat to the stream, to prevent the connection from closing when there are no messages being sent.""" + queue = asyncio.Queue[Tuple[bool, Any]](1) + + cancelling = False + + async def drain() -> None: + try: + async for item in stream(): + await queue.put((False, item)) + except Exception as e: + if not cancelling: + await queue.put((True, e)) + else: + raise + + async def heartbeat() -> None: + while True: + await queue.put((False, self.encode_multipart_data({}, "graphql"))) + + await asyncio.sleep(5) + + async def merged() -> AsyncGenerator[str, None]: + heartbeat_task = asyncio.create_task(heartbeat()) + task = asyncio.create_task(drain()) + + async def cancel_tasks() -> None: + nonlocal cancelling + cancelling = True + task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await task + + heartbeat_task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await heartbeat_task + + try: + while not task.done(): + raised, data = await queue.get() + + if raised: + await cancel_tasks() + raise data + + yield data + finally: + await cancel_tasks() + + return merged + + def _get_stream( + self, + request: Request, + result: SubscriptionExecutionResult, + separator: str = "graphql", + ) -> Callable[[], AsyncGenerator[str, None]]: + async def stream() -> AsyncGenerator[str, None]: + async for value in result: + response = await self.process_result(request, value) + yield self.encode_multipart_data({"payload": response}, separator) + + yield f"\r\n--{separator}--\r\n" + + return self._stream_with_heartbeat(stream) + + async def parse_multipart_subscriptions( + self, request: AsyncHTTPRequestAdapter + ) -> Dict[str, str]: + if request.method == "GET": + return self.parse_query_params(request.query_params) + + return self.parse_json(await request.get_body()) + async def parse_http_body( self, request: AsyncHTTPRequestAdapter ) -> GraphQLRequestData: - content_type = request.content_type or "" + content_type, params = parse_content_type(request.content_type or "") if request.method == "GET": data = self.parse_query_params(request.query_params) elif "application/json" in content_type: data = self.parse_json(await request.get_body()) - elif content_type.startswith("multipart/form-data"): + elif content_type == "multipart/form-data": data = await self.parse_multipart(request) + elif self._is_multipart_subscriptions(content_type, params): + data = await self.parse_multipart_subscriptions(request) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/http/base.py b/strawberry/http/base.py index c29a52d362..7f8e1802bc 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -69,5 +69,16 @@ def graphql_ide_html(self) -> str: graphql_ide=self.graphql_ide, ) + def _is_multipart_subscriptions( + self, content_type: str, params: Dict[str, str] + ) -> bool: + if content_type != "multipart/mixed": + return False + + if params.get("boundary") != "graphql": + return False + + return params.get("subscriptionspec", "").startswith("1.0") + __all__ = ["BaseView"] diff --git a/strawberry/http/parse_content_type.py b/strawberry/http/parse_content_type.py new file mode 100644 index 0000000000..d28be1a337 --- /dev/null +++ b/strawberry/http/parse_content_type.py @@ -0,0 +1,16 @@ +from email.message import Message +from typing import Dict, Tuple + + +def parse_content_type(content_type: str) -> Tuple[str, Dict[str, str]]: + """Parse a content type header into a mime-type and a dictionary of parameters.""" + email = Message() + email["content-type"] = content_type + + params = email.get_params() + + assert params + + mime_type, _ = params.pop(0) + + return mime_type, dict(params) diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 1a368b3302..f1ce7ca19a 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -16,7 +16,11 @@ from strawberry import UNSET from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files -from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData, process_result +from strawberry.http import ( + GraphQLHTTPResponse, + GraphQLRequestData, + process_result, +) from strawberry.http.ides import GraphQL_IDE from strawberry.schema import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError @@ -25,6 +29,7 @@ from .base import BaseView from .exceptions import HTTPException +from .parse_content_type import parse_content_type from .types import HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse @@ -131,14 +136,19 @@ def parse_multipart(self, request: SyncHTTPRequestAdapter) -> Dict[str, str]: raise HTTPException(400, "File(s) missing in form data") from e def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData: - content_type = request.content_type or "" + content_type, params = parse_content_type(request.content_type or "") if request.method == "GET": data = self.parse_query_params(request.query_params) elif "application/json" in content_type: data = self.parse_json(request.body) - elif content_type.startswith("multipart/form-data"): + # TODO: multipart via get? + elif content_type == "multipart/form-data": data = self.parse_multipart(request) + elif self._is_multipart_subscriptions(content_type, params): + raise HTTPException( + 400, "Multipart subcriptions are not supported in sync mode" + ) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 41a459a061..dc4e37a0af 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -7,6 +7,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Callable, Dict, FrozenSet, List, @@ -34,6 +36,7 @@ from litestar.background_tasks import BackgroundTasks from litestar.di import Provide from litestar.exceptions import NotFoundException, ValidationException +from litestar.response.streaming import Stream from litestar.status_codes import HTTP_200_OK from strawberry.exceptions import InvalidCustomContext from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter @@ -183,7 +186,13 @@ def headers(self) -> Mapping[str, str]: @property def content_type(self) -> Optional[str]: - return self.request.content_type[0] + content_type, params = self.request.content_type + + # combine content type and params + if params: + content_type += "; " + "; ".join(f"{k}={v}" for k, v in params.items()) + + return content_type async def get_body(self) -> bytes: return await self.request.body() @@ -271,6 +280,22 @@ def create_response( return response + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, + ) -> Response: + return Stream( + stream(), + status_code=sub_response.status_code, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + @get(raises=[ValidationException, NotFoundException]) async def handle_http_get( self, diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index f56a95bfd2..e6938a6034 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Optional, cast from quart import Request, Response, request from quart.views import View @@ -103,5 +103,21 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore status=e.status_code, ) + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: Response, + ) -> Response: + return ( + stream(), + sub_response.status_code, + { # type: ignore + **sub_response.headers, + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + __all__ = ["GraphQLView"] diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 6c2e97f425..83b7d3ca5c 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -5,6 +5,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, + Callable, Dict, Mapping, Optional, @@ -161,16 +163,46 @@ def create_response( ) async def post(self, request: Request) -> HTTPResponse: + self.request = request + try: return await self.run(request) except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) async def get(self, request: Request) -> HTTPResponse: # type: ignore[override] + self.request = request + try: return await self.run(request) except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, + ) -> HTTPResponse: + response = await self.request.respond( + content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + status=sub_response.status_code, + headers={ + **sub_response.headers, + "Transfer-Encoding": "chunked", + }, + ) + + async for chunk in stream(): + await response.send(chunk) + + await response.eof() + + # returning the response will basically tell sanic to send it again + # to the client, so we return None to avoid that, and we ignore the type + # error mostly so we don't have to update the types everywhere for this + # corner case + return None # type: ignore + __all__ = ["GraphQLView"] diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index da15136829..9e040d3856 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -12,7 +12,11 @@ from strawberry.directive import StrawberryDirective from strawberry.schema.schema_converter import GraphQLCoreConverter - from strawberry.types import ExecutionContext, ExecutionResult + from strawberry.types import ( + ExecutionContext, + ExecutionResult, + SubscriptionExecutionResult, + ) from strawberry.types.base import StrawberryObjectDefinition from strawberry.types.enum import EnumDefinition from strawberry.types.graphql import OperationType @@ -39,7 +43,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> ExecutionResult: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index fe1b1fb7cd..036f50c540 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -15,7 +15,7 @@ Union, ) -from graphql import GraphQLError, parse +from graphql import GraphQLError, parse, subscribe from graphql import execute as original_execute from graphql.validation import validate @@ -23,6 +23,7 @@ from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.schema.validation_rules.one_of import OneOfInputValidationRule from strawberry.types import ExecutionResult +from strawberry.types.graphql import OperationType from .exceptions import InvalidOperationTypeError @@ -36,7 +37,7 @@ from strawberry.extensions import SchemaExtension from strawberry.types import ExecutionContext - from strawberry.types.graphql import OperationType + from strawberry.types.execution import SubscriptionExecutionResult # duplicated because of https://github.com/mkdocstrings/griffe-typingdoc/issues/7 @@ -84,7 +85,7 @@ async def execute( execution_context: ExecutionContext, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], -) -> ExecutionResult: +) -> Union[ExecutionResult, SubscriptionExecutionResult]: extensions_runner = SchemaExtensionsRunner( execution_context=execution_context, extensions=list(extensions), @@ -124,16 +125,28 @@ async def execute( async with extensions_runner.executing(): if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) + if execution_context.operation_type == OperationType.SUBSCRIPTION: + # TODO: should we process errors here? + # TODO: make our own wrapper? + return await subscribe( # type: ignore + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + context_value=execution_context.context, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + ) + else: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, + ) if isawaitable(result): result = await result diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 4937a2a6eb..ec4bf0d64a 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -53,7 +53,7 @@ from strawberry.directive import StrawberryDirective from strawberry.extensions import SchemaExtension - from strawberry.types import ExecutionResult + from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.base import StrawberryType from strawberry.types.enum import EnumDefinition from strawberry.types.field import StrawberryField @@ -284,7 +284,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> ExecutionResult: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index fa0cb7c177..1275ecf304 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -256,7 +256,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: else: # create AsyncGenerator returning a single result async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( + yield await self.schema.execute( # type: ignore query=message.payload.query, variable_values=message.payload.variables, context_value=context, diff --git a/strawberry/types/__init__.py b/strawberry/types/__init__.py index 1de5482d7e..65f055865c 100644 --- a/strawberry/types/__init__.py +++ b/strawberry/types/__init__.py @@ -1,10 +1,12 @@ from .base import get_object_definition, has_object_definition -from .execution import ExecutionContext, ExecutionResult +from .execution import ExecutionContext, ExecutionResult, SubscriptionExecutionResult from .info import Info __all__ = [ "ExecutionContext", "ExecutionResult", + "SubscriptionExecutionResult", + "Info", "Info", "get_object_definition", "has_object_definition", diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 3cfda6556d..e1f88dbf21 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -9,8 +9,9 @@ Optional, Tuple, Type, + runtime_checkable, ) -from typing_extensions import TypedDict +from typing_extensions import Protocol, TypedDict from graphql import specified_rules @@ -96,4 +97,18 @@ class ParseOptions(TypedDict): max_tokens: NotRequired[int] -__all__ = ["ExecutionContext", "ExecutionResult", "ParseOptions"] +@runtime_checkable +class SubscriptionExecutionResult(Protocol): + def __aiter__(self) -> SubscriptionExecutionResult: # pragma: no cover + ... + + async def __anext__(self) -> Any: # pragma: no cover + ... + + +__all__ = [ + "ExecutionContext", + "ExecutionResult", + "ParseOptions", + "SubscriptionExecutionResult", +] diff --git a/strawberry/types/graphql.py b/strawberry/types/graphql.py index 92793bed3f..d0bbd0237d 100644 --- a/strawberry/types/graphql.py +++ b/strawberry/types/graphql.py @@ -15,7 +15,11 @@ class OperationType(enum.Enum): @staticmethod def from_http(method: HTTPMethod) -> Set[OperationType]: if method == "GET": - return {OperationType.QUERY} + return { + OperationType.QUERY, + # subscriptions are supported via GET in the multipart protocol + OperationType.SUBSCRIPTION, + } if method == "POST": return { diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 97f8c29027..0e8bfca0ed 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -1,7 +1,7 @@ from __future__ import annotations from django.core.exceptions import BadRequest, SuspiciousOperation -from django.http import Http404, HttpRequest, HttpResponse +from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse from django.test.client import RequestFactory from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView @@ -48,16 +48,21 @@ async def _do_request(self, request: RequestFactory) -> Response: try: response = await view(request) except Http404: - return Response( - status_code=404, data=b"Not found", headers=response.headers - ) + return Response(status_code=404, data=b"Not found", headers={}) except (BadRequest, SuspiciousOperation) as e: return Response( - status_code=400, data=e.args[0].encode(), headers=response.headers - ) - else: - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, + status_code=400, + data=e.args[0].encode(), + headers={}, ) + data = ( + response.streaming_content + if isinstance(response, StreamingHttpResponse) + else response.content + ) + + return Response( + status_code=response.status_code, + data=data, + headers=response.headers, + ) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 8ccbaa1d62..0ea062c34c 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -1,16 +1,21 @@ import abc +import contextlib import json +import logging from dataclasses import dataclass +from functools import cached_property from io import BytesIO from typing import ( Any, AsyncContextManager, AsyncGenerator, + AsyncIterable, Callable, Dict, List, Mapping, Optional, + Union, ) from typing_extensions import Literal @@ -18,6 +23,8 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult +logger = logging.getLogger("strawberry.test.http_client") + JSON = Dict[str, object] ResultOverrideFunction = Optional[Callable[[ExecutionResult], GraphQLHTTPResponse]] @@ -25,17 +32,68 @@ @dataclass class Response: status_code: int - data: bytes - headers: Mapping[str, str] + data: Union[bytes, AsyncIterable[bytes]] + + def __init__( + self, + status_code: int, + data: Union[bytes, AsyncIterable[bytes]], + *, + headers: Optional[Dict[str, str]] = None, + ) -> None: + self.status_code = status_code + self.data = data + self._headers = headers or {} + + @cached_property + def headers(self) -> Mapping[str, str]: + return {k.lower(): v for k, v in self._headers.items()} + + @property + def is_multipart(self) -> bool: + return self.headers.get("content-type", "").startswith("multipart/mixed") @property def text(self) -> str: + assert isinstance(self.data, bytes) return self.data.decode() @property def json(self) -> JSON: + assert isinstance(self.data, bytes) return json.loads(self.data) + async def streaming_json(self) -> AsyncIterable[JSON]: + if not self.is_multipart: + raise ValueError("Streaming not supported") + + def parse_chunk(text: str) -> Union[JSON, None]: + # TODO: better parsing? :) + with contextlib.suppress(json.JSONDecodeError): + return json.loads(text) + + if isinstance(self.data, AsyncIterable): + chunks = self.data + + async for chunk in chunks: + lines = chunk.decode("utf-8").split("\r\n") + + for text in lines: + if data := parse_chunk(text): + yield data + else: + # TODO: we do this because httpx doesn't support streaming + # it would be nice to fix httpx instead of doing this, + # but we might have the same issue in other clients too + # TODO: better message + logger.warning("Didn't receive a stream, parsing it sync") + + chunks = self.data.decode("utf-8").split("\r\n") + + for chunk in chunks: + if data := parse_chunk(chunk): + yield data + class HttpClient(abc.ABC): @abc.abstractmethod @@ -100,16 +158,18 @@ def _get_headers( headers: Optional[Dict[str, str]], files: Optional[Dict[str, BytesIO]], ) -> Dict[str, str]: - addition_headers = {} + additional_headers = {} + headers = headers or {} - content_type = None + # TODO: fix case sensitivity + content_type = headers.get("content-type") - if method == "post" and not files: + if not content_type and method == "post" and not files: content_type = "application/json" - addition_headers = {"Content-Type": content_type} if content_type else {} + additional_headers = {"Content-Type": content_type} if content_type else {} - return addition_headers if headers is None else {**addition_headers, **headers} + return {**additional_headers, **headers} def _build_body( self, diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index f271509f40..1a8148c136 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -110,6 +110,14 @@ def create_app(self, **kwargs: Any) -> None: self.client = TestClient(self.app) + async def _handle_response(self, response: Any) -> Response: + # TODO: here we should handle the stream + return Response( + status_code=response.status_code, + data=response.content, + headers=response.headers, + ) + async def _graphql_request( self, method: Literal["get", "post"], @@ -141,11 +149,7 @@ async def _graphql_request( **kwargs, ) - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return await self._handle_response(response) async def request( self, @@ -155,11 +159,7 @@ async def request( ) -> Response: response = getattr(self.client, method)(url, headers=headers) - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return await self._handle_response(response) async def get( self, @@ -177,11 +177,7 @@ async def post( ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return await self._handle_response(response) @contextlib.asynccontextmanager async def ws_connect( diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py new file mode 100644 index 0000000000..d9cb289d8c --- /dev/null +++ b/tests/http/test_multipart_subscription.py @@ -0,0 +1,73 @@ +import contextlib +from typing import Type +from typing_extensions import Literal + +import pytest + +from .clients.base import HttpClient + + +@pytest.fixture() +def http_client(http_client_class: Type[HttpClient]) -> HttpClient: + with contextlib.suppress(ImportError): + import django + + if django.VERSION < (4, 2): + pytest.skip(reason="Django < 4.2 doesn't async streaming responses") + + from .clients.django import DjangoHttpClient + + if http_client_class is DjangoHttpClient: + pytest.skip( + reason="(sync) DjangoHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.channels import SyncChannelsHttpClient + + # TODO: why do we have a sync channels client? + if http_client_class is SyncChannelsHttpClient: + pytest.skip( + reason="SyncChannelsHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.async_flask import AsyncFlaskHttpClient + from .clients.flask import FlaskHttpClient + + if http_client_class is FlaskHttpClient: + pytest.skip( + reason="FlaskHttpClient doesn't support multipart subscriptions" + ) + + if http_client_class is AsyncFlaskHttpClient: + pytest.xfail( + reason="AsyncFlaskHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.chalice import ChaliceHttpClient + + if http_client_class is ChaliceHttpClient: + pytest.skip( + reason="ChaliceHttpClient doesn't support multipart subscriptions" + ) + + return http_client_class() + + +@pytest.mark.parametrize("method", ["get", "post"]) +async def test_multipart_subscription( + http_client: HttpClient, method: Literal["get", "post"] +): + response = await http_client.query( + method=method, + query='subscription { echo(message: "Hello world", delay: 0.2) }', + headers={ + "content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + + data = [d async for d in response.streaming_json()] + + assert data == [{"payload": {"data": {"echo": "Hello world"}}}] diff --git a/tests/http/test_parse_content_type.py b/tests/http/test_parse_content_type.py new file mode 100644 index 0000000000..4ae0017e40 --- /dev/null +++ b/tests/http/test_parse_content_type.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple + +import pytest + +from strawberry.http.parse_content_type import parse_content_type + + +@pytest.mark.parametrize( + ("content_type", "expected"), + [ # type: ignore + ("application/json", ("application/json", {})), + ("", ("", {})), + ("application/json; charset=utf-8", ("application/json", {"charset": "utf-8"})), + ( + "application/json; charset=utf-8; boundary=foobar", + ("application/json", {"charset": "utf-8", "boundary": "foobar"}), + ), + ( + "application/json; boundary=foobar; charset=utf-8", + ("application/json", {"boundary": "foobar", "charset": "utf-8"}), + ), + ( + "application/json; boundary=foobar", + ("application/json", {"boundary": "foobar"}), + ), + ( + "application/json; boundary=foobar; charset=utf-8; foo=bar", + ( + "application/json", + {"boundary": "foobar", "charset": "utf-8", "foo": "bar"}, + ), + ), + ( + 'multipart/mixed; boundary="graphql"; subscriptionSpec=1.0, application/json', + ( + "multipart/mixed", + { + "boundary": "graphql", + "subscriptionspec": "1.0, application/json", + }, + ), + ), + ], +) +async def test_parse_content_type( + content_type: str, + expected: Tuple[str, Dict[str, str]], +): + assert parse_content_type(content_type) == expected diff --git a/tests/http/test_query.py b/tests/http/test_query.py index 183aa5ff06..85a9f46889 100644 --- a/tests/http/test_query.py +++ b/tests/http/test_query.py @@ -225,4 +225,4 @@ async def test_updating_headers( assert response.status_code == 200 assert response.json["data"] == {"setHeader": "Jake"} - assert response.headers["X-Name"] == "Jake" + assert response.headers["x-name"] == "Jake"