From d3f3ecd0256643305fdb51fcd864297763075d45 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 6 Sep 2023 11:49:03 +0100 Subject: [PATCH 01/57] POC for multipart subscriptions --- strawberry/http/async_base_view.py | 24 +++++++++++++++++++++++ tests/http/test_multipart_subscription.py | 20 +++++++++++++++++++ tests/views/schema.py | 2 +- 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 tests/http/test_multipart_subscription.py diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index bab2e9647d..dab44fbed9 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -112,6 +112,14 @@ async def execute_operation( assert self.schema + return await self.schema.subscribe( + request_data.query, + root_value=root_value, + variable_values=request_data.variables, + context_value=context, + operation_name=request_data.operation_name, + ) + return await self.schema.execute( request_data.query, root_value=root_value, @@ -185,10 +193,18 @@ async def run( response_data = await self.process_result(request=request, result=result) + # only if is a multipart subscription + return self.create_multipart_response(response_data, sub_response) + return self.create_response( response_data=response_data, sub_response=sub_response ) + async def parse_multipart_subscriptions( + self, request: AsyncHTTPRequestAdapter + ) -> Dict[str, str]: + return self.parse_json(await request.get_body()) + async def parse_http_body( self, request: AsyncHTTPRequestAdapter ) -> GraphQLRequestData: @@ -198,6 +214,9 @@ async def parse_http_body( data = self.parse_json(await request.get_body()) elif content_type.startswith("multipart/form-data"): data = await self.parse_multipart(request) + elif content_type.startswith("multipart/mixed"): + # TODO: do a check that checks if this is a multipart subscription + data = await self.parse_multipart_subscriptions(request) elif request.method == "GET": data = self.parse_query_params(request.query_params) else: @@ -212,4 +231,9 @@ async def parse_http_body( async def process_result( self, request: Request, result: ExecutionResult ) -> GraphQLHTTPResponse: + # check if result is iterable + if hasattr(result, "__aiter__"): + return [await self.process_result(request, value) async for value in result] + + breakpoint() return process_result(result) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py new file mode 100644 index 0000000000..32afc0a1ed --- /dev/null +++ b/tests/http/test_multipart_subscription.py @@ -0,0 +1,20 @@ + + +from .clients.base import HttpClient + +# TODO: do multipart subscriptions work on both GET and POST? + + +async def test_graphql_query(http_client: HttpClient): + response = await http_client.post( + url="/graphql", + json={ + "query": 'subscription { echo2(message: "Hello world", delay: 0.2) }', + }, + headers={ + # TODO: this might just be for django + "CONTENT_TYPE": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + + print(response.data) diff --git a/tests/views/schema.py b/tests/views/schema.py index e5a3dc5377..5e6b9aeda0 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -151,7 +151,7 @@ def match_text(self, text_file: Upload, pattern: str) -> str: @strawberry.type class Subscription: @strawberry.subscription - async def echo(self, message: str, delay: float = 0) -> AsyncGenerator[str, None]: + async def echo2(self, message: str, delay: float = 0) -> AsyncGenerator[str, None]: await asyncio.sleep(delay) yield message From 865da3f743e63c573387ecd114ad4f945e2e6387 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:54:39 +0000 Subject: [PATCH 02/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/http/test_multipart_subscription.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 32afc0a1ed..ae6186ef4f 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,5 +1,3 @@ - - from .clients.base import HttpClient # TODO: do multipart subscriptions work on both GET and POST? From 3f6b7cba0dd892086d5586268934a7b44889fbf6 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Mon, 18 Sep 2023 16:44:32 +0100 Subject: [PATCH 03/57] Fix name --- tests/http/test_multipart_subscription.py | 2 +- tests/views/schema.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index ae6186ef4f..786188be79 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -7,7 +7,7 @@ async def test_graphql_query(http_client: HttpClient): response = await http_client.post( url="/graphql", json={ - "query": 'subscription { echo2(message: "Hello world", delay: 0.2) }', + "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', }, headers={ # TODO: this might just be for django diff --git a/tests/views/schema.py b/tests/views/schema.py index 5e6b9aeda0..e5a3dc5377 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -151,7 +151,7 @@ def match_text(self, text_file: Upload, pattern: str) -> str: @strawberry.type class Subscription: @strawberry.subscription - async def echo2(self, message: str, delay: float = 0) -> AsyncGenerator[str, None]: + async def echo(self, message: str, delay: float = 0) -> AsyncGenerator[str, None]: await asyncio.sleep(delay) yield message From 1f24cc0c445d1548001685cdb901e89d7edecac6 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Mon, 18 Sep 2023 17:18:26 +0100 Subject: [PATCH 04/57] Progress --- strawberry/django/views.py | 16 +++++++++++++++- strawberry/http/async_base_view.py | 20 ++++++++++++++------ tests/http/clients/async_django.py | 9 ++++++++- tests/http/clients/base.py | 6 +++++- tests/http/conftest.py | 1 - tests/http/test_multipart_subscription.py | 6 ++++-- 6 files changed, 46 insertions(+), 12 deletions(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 77f8cfdc4e..dc215750c6 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -12,7 +12,12 @@ cast, ) -from django.http import HttpRequest, HttpResponseNotAllowed, JsonResponse +from django.http import ( + HttpRequest, + HttpResponseNotAllowed, + JsonResponse, + StreamingHttpResponse, +) from django.http.response import HttpResponse from django.template import RequestContext, Template from django.template.exceptions import TemplateDoesNotExist @@ -176,6 +181,15 @@ def create_response( return response + async def create_multipart_response( + self, response_stream: ..., sub_response: HttpResponse + ) -> HttpResponse: + async def event_stream(): + async for x in response_stream: + yield x + + return StreamingHttpResponse(streaming_content=event_stream()) + class GraphQLView( BaseView, diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index dab44fbed9..a9deb6e40b 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -92,6 +92,12 @@ def create_response( ) -> Response: ... + @abc.abstractmethod + def create_multipart_response( + self, response_data: GraphQLHTTPResponse, sub_response: SubResponse + ) -> Response: + ... + async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] ) -> ExecutionResult: @@ -191,10 +197,17 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e + if hasattr(result, "__aiter__"): + + async def stream(): + async for value in result: + yield await self.process_result(request, value) + + return await self.create_multipart_response(stream(), sub_response) + response_data = await self.process_result(request=request, result=result) # only if is a multipart subscription - return self.create_multipart_response(response_data, sub_response) return self.create_response( response_data=response_data, sub_response=sub_response @@ -231,9 +244,4 @@ async def parse_http_body( async def process_result( self, request: Request, result: ExecutionResult ) -> GraphQLHTTPResponse: - # check if result is iterable - if hasattr(result, "__aiter__"): - return [await self.process_result(request, value) async for value in result] - - breakpoint() return process_result(result) diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 440080f9b9..931629daa8 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -1,7 +1,7 @@ from __future__ import annotations from django.core.exceptions import BadRequest, SuspiciousOperation -from django.http import Http404, HttpRequest, HttpResponse +from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse from django.test.client import RequestFactory from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView @@ -54,6 +54,13 @@ async def _do_request(self, request: RequestFactory) -> Response: return Response( status_code=400, data=e.args[0].encode(), headers=response.headers ) + + if isinstance(response, StreamingHttpResponse): + return Response( + status_code=response.status_code, + data=response.streaming_content, + headers=response.headers, + ) else: return Response( status_code=response.status_code, diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index da4afef762..d0227c2aa1 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -6,11 +6,13 @@ Any, AsyncContextManager, AsyncGenerator, + AsyncIterable, Callable, Dict, List, Mapping, Optional, + Union, ) from typing_extensions import Literal @@ -24,15 +26,17 @@ @dataclass class Response: status_code: int - data: bytes + data: Union[bytes, AsyncIterable[bytes]] headers: Mapping[str, str] @property def text(self) -> str: + assert isinstance(self.data, bytes) return self.data.decode() @property def json(self) -> JSON: + assert isinstance(self.data, bytes) return json.loads(self.data) diff --git a/tests/http/conftest.py b/tests/http/conftest.py index 81d99115f3..cc32163000 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -18,7 +18,6 @@ def _get_http_client_classes() -> Generator[Any, None, None]: ("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]), ("FlaskHttpClient", "flask", [pytest.mark.flask]), ("SanicHttpClient", "sanic", [pytest.mark.sanic]), - ("StarliteHttpClient", "starlite", [pytest.mark.starlite]), ( "SyncChannelsHttpClient", "channels", diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 786188be79..80a102c9ed 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -10,9 +10,11 @@ async def test_graphql_query(http_client: HttpClient): "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', }, headers={ - # TODO: this might just be for django + # TODO: this header might just be for django "CONTENT_TYPE": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, ) - print(response.data) + async for d in response.data: + print(d) + assert response.data == "" From 93e40417c64fba97d041e2a3c74f12f4e1ed77de Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 19 Sep 2023 07:04:42 -0700 Subject: [PATCH 05/57] WIP --- strawberry/django/views.py | 12 +++++++++--- strawberry/http/async_base_view.py | 3 --- tests/http/clients/base.py | 6 ++++++ tests/http/test_multipart_subscription.py | 6 +++--- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index dc215750c6..3d0d36ef5c 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -185,10 +185,16 @@ async def create_multipart_response( self, response_stream: ..., sub_response: HttpResponse ) -> HttpResponse: async def event_stream(): - async for x in response_stream: - yield x + async for data in response_stream: + yield "\r\n--graphql\r\n" + yield "Content-Type: application/json\r\n\r\n" - return StreamingHttpResponse(streaming_content=event_stream()) + yield self.encode_json(data) +"\n" # type: ignore + + return StreamingHttpResponse(streaming_content=event_stream(), headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json" + }) class GraphQLView( diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index a9deb6e40b..a8c0bdec1d 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -198,7 +198,6 @@ async def run( raise HTTPException(400, "No GraphQL query found in the request") from e if hasattr(result, "__aiter__"): - async def stream(): async for value in result: yield await self.process_result(request, value) @@ -207,8 +206,6 @@ async def stream(): response_data = await self.process_result(request=request, result=result) - # only if is a multipart subscription - return self.create_response( response_data=response_data, sub_response=sub_response ) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index d0227c2aa1..d0fd5a6de4 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -39,6 +39,12 @@ def json(self) -> JSON: assert isinstance(self.data, bytes) return json.loads(self.data) + async def streaming_json(self) -> AsyncIterable[JSON]: + assert isinstance(self.data, AsyncIterable) + + async for data in self.data: + yield json.loads(data) + class HttpClient(abc.ABC): @abc.abstractmethod diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 80a102c9ed..0f40f8e336 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -15,6 +15,6 @@ async def test_graphql_query(http_client: HttpClient): }, ) - async for d in response.data: - print(d) - assert response.data == "" + data = [d async for d in response.streaming_json()] + + assert data == [{'data': {'echo': 'Hello world'}}] From 8ecfae76cc3c2195993dfb1179d157d8c5bb64bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 14:05:49 +0000 Subject: [PATCH 06/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/django/views.py | 15 +++++++++------ strawberry/http/async_base_view.py | 1 + tests/http/test_multipart_subscription.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 3d0d36ef5c..432b996172 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -189,12 +189,15 @@ async def event_stream(): yield "\r\n--graphql\r\n" yield "Content-Type: application/json\r\n\r\n" - yield self.encode_json(data) +"\n" # type: ignore - - return StreamingHttpResponse(streaming_content=event_stream(), headers={ - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json" - }) + yield self.encode_json(data) + "\n" # type: ignore + + return StreamingHttpResponse( + streaming_content=event_stream(), + headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) class GraphQLView( diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index a8c0bdec1d..25acba31c3 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -198,6 +198,7 @@ async def run( raise HTTPException(400, "No GraphQL query found in the request") from e if hasattr(result, "__aiter__"): + async def stream(): async for value in result: yield await self.process_result(request, value) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 0f40f8e336..be3da18b50 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -17,4 +17,4 @@ async def test_graphql_query(http_client: HttpClient): data = [d async for d in response.streaming_json()] - assert data == [{'data': {'echo': 'Hello world'}}] + assert data == [{"data": {"echo": "Hello world"}}] From 8ca011834ca08a26980df73dde0fdb802ca7f72f Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 7 Oct 2023 16:10:38 +0100 Subject: [PATCH 07/57] Make test pass with django async --- tests/http/clients/base.py | 17 ++++++++++++++++- tests/http/test_multipart_subscription.py | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index d0fd5a6de4..15b260f7a1 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -1,4 +1,5 @@ import abc +import contextlib import json from dataclasses import dataclass from io import BytesIO @@ -29,6 +30,11 @@ class Response: data: Union[bytes, AsyncIterable[bytes]] headers: Mapping[str, str] + @property + def is_multipart(self) -> bool: + # TODO: check casing + return self.headers.get("Content-type", "").startswith("multipart/mixed") + @property def text(self) -> str: assert isinstance(self.data, bytes) @@ -42,8 +48,17 @@ def json(self) -> JSON: async def streaming_json(self) -> AsyncIterable[JSON]: assert isinstance(self.data, AsyncIterable) + if not self.is_multipart: + raise ValueError("Streaming not supported") + + # assuming we receive lines + async for data in self.data: - yield json.loads(data) + text = data.decode("utf-8").strip() + + # TODO: this is silly, but bear with me :) + with contextlib.suppress(json.JSONDecodeError): + yield json.loads(text) class HttpClient(abc.ABC): diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index be3da18b50..dfd6c948bc 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -10,7 +10,7 @@ async def test_graphql_query(http_client: HttpClient): "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', }, headers={ - # TODO: this header might just be for django + # TODO: this header might just be for django (the way it is written) "CONTENT_TYPE": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, ) From afc45561e8d63f9290a869df9ac09266acd65fe1 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 7 Oct 2023 16:46:54 +0100 Subject: [PATCH 08/57] Attempt with FastAPI, type fixes --- strawberry/django/views.py | 26 ++++++++------------ strawberry/fastapi/router.py | 12 +++++++++ strawberry/http/async_base_view.py | 8 ++++-- tests/http/clients/fastapi.py | 30 +++++++++++------------ tests/http/test_multipart_subscription.py | 20 ++++++++++++++- 5 files changed, 61 insertions(+), 35 deletions(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 432b996172..f28bb3c6ea 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -14,11 +14,12 @@ from django.http import ( HttpRequest, + HttpResponse, HttpResponseNotAllowed, JsonResponse, StreamingHttpResponse, ) -from django.http.response import HttpResponse +from django.http.response import HttpResponseBase from django.template import RequestContext, Template from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string @@ -147,7 +148,7 @@ def __init__( super().__init__(**kwargs) - def render_graphiql(self, request: HttpRequest) -> HttpResponse: + def render_graphiql(self, request: HttpRequest) -> HttpResponseBase: try: template = Template(render_to_string("graphql/graphiql.html")) except TemplateDoesNotExist: @@ -162,7 +163,7 @@ def render_graphiql(self, request: HttpRequest) -> HttpResponse: def create_response( self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse - ) -> HttpResponse: + ) -> HttpResponseBase: data = self.encode_json(response_data) # type: ignore response = HttpResponse( @@ -183,16 +184,9 @@ def create_response( async def create_multipart_response( self, response_stream: ..., sub_response: HttpResponse - ) -> HttpResponse: - async def event_stream(): - async for data in response_stream: - yield "\r\n--graphql\r\n" - yield "Content-Type: application/json\r\n\r\n" - - yield self.encode_json(data) + "\n" # type: ignore - + ) -> HttpResponseBase: return StreamingHttpResponse( - streaming_content=event_stream(), + streaming_content=response_stream(), headers={ "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", @@ -203,7 +197,7 @@ async def event_stream(): class GraphQLView( BaseView, SyncBaseHTTPView[ - HttpRequest, HttpResponse, TemporalHttpResponse, Context, RootValue + HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue ], View, ): @@ -225,7 +219,7 @@ def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: @method_decorator(csrf_exempt) def dispatch( self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]: + ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: try: return self.run(request=request) except HTTPException as e: @@ -238,7 +232,7 @@ def dispatch( class AsyncGraphQLView( BaseView, AsyncBaseHTTPView[ - HttpRequest, HttpResponse, TemporalHttpResponse, Context, RootValue + HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue ], View, ): @@ -269,7 +263,7 @@ async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: @method_decorator(csrf_exempt) async def dispatch( # pyright: ignore self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponse]: + ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: try: return await self.run(request=request) except HTTPException as e: diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 5ebc552ab0..6418ca6dc6 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -21,6 +21,7 @@ HTMLResponse, PlainTextResponse, Response, + StreamingResponse, ) from starlette.websockets import WebSocket @@ -311,3 +312,14 @@ def create_response( response.headers.raw.extend(sub_response.headers.raw) return response + + async def create_multipart_response( + self, response_stream: ..., sub_response: Response + ) -> Response: + return StreamingResponse( + response_stream(), + headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 25acba31c3..933e3e5bfe 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -201,9 +201,13 @@ async def run( async def stream(): async for value in result: - yield await self.process_result(request, value) + yield "\r\n--graphql\r\n" + yield "Content-Type: application/json\r\n\r\n" + data = await self.process_result(request, value) - return await self.create_multipart_response(stream(), sub_response) + yield self.encode_json(data) + "\n" # type: ignore + + return await self.create_multipart_response(stream, sub_response) response_data = await self.process_result(request=request, result=result) diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index bf5d33bf7b..91481eea0e 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -107,6 +107,14 @@ def create_app(self, **kwargs: Any) -> None: self.client = TestClient(self.app) + async def _handle_response(self, response: Any) -> Response: + # TODO: here we should handle the stream + return Response( + status_code=response.status_code, + data=response.content, + headers=response.headers, + ) + async def _graphql_request( self, method: Literal["get", "post"], @@ -138,11 +146,7 @@ async def _graphql_request( **kwargs, ) - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return await self._handle_response(response) async def request( self, @@ -152,11 +156,7 @@ async def request( ) -> Response: response = getattr(self.client, method)(url, headers=headers) - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return await self._handle_response(response) async def get( self, @@ -172,14 +172,12 @@ async def post( json: Optional[JSON] = None, headers: Optional[Dict[str, str]] = None, ) -> Response: - response = self.client.post(url, headers=headers, content=data, json=json) - - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, + response = self.client.post( + url, headers=headers, content=data, json=json, stream=True ) + return await self._handle_response(response) + @contextlib.asynccontextmanager async def ws_connect( self, diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index dfd6c948bc..97d4acaa1b 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,8 +1,26 @@ +import contextlib +from typing import Type + +import pytest + from .clients.base import HttpClient -# TODO: do multipart subscriptions work on both GET and POST? +@pytest.fixture() +def http_client(http_client_class: Type[HttpClient]) -> HttpClient: + with contextlib.suppress(ImportError): + from .clients.fastapi import FastAPIHttpClient + + if http_client_class is FastAPIHttpClient: + # TODO: we could test this, but it doesn't make a lot of sense + # we should fix httpx instead :) + # https://github.com/encode/httpx/issues/2186 + pytest.xfail(reason="HTTPX doesn't support streaming yet") + return http_client_class() + + +# TODO: do multipart subscriptions work on both GET and POST? async def test_graphql_query(http_client: HttpClient): response = await http_client.post( url="/graphql", From 7715736a7341784352b904f015ac996688241513 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 13 Oct 2023 18:06:54 -0700 Subject: [PATCH 09/57] Improve code a bit --- strawberry/http/async_base_view.py | 37 +++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 7914c85d2c..40f7153e0a 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -1,6 +1,7 @@ import abc import json from typing import ( + AsyncGenerator, Callable, Dict, Generic, @@ -10,7 +11,7 @@ Union, ) -from graphql import GraphQLError +from graphql import GraphQLError, MapAsyncIterator from strawberry import UNSET from strawberry.exceptions import MissingQueryError @@ -96,14 +97,14 @@ def create_response( ... @abc.abstractmethod - def create_multipart_response( - self, response_data: GraphQLHTTPResponse, sub_response: SubResponse + async def create_multipart_response( + self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: SubResponse ) -> Response: ... async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] - ) -> ExecutionResult: + ) -> Union[ExecutionResult, MapAsyncIterator]: request_adapter = self.request_adapter_class(request) try: @@ -121,6 +122,7 @@ async def execute_operation( assert self.schema + # TODO: check if this is a subscription return await self.schema.subscribe( request_data.query, root_value=root_value, @@ -207,15 +209,9 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e - if hasattr(result, "__aiter__"): - - async def stream(): - async for value in result: - yield "\r\n--graphql\r\n" - yield "Content-Type: application/json\r\n\r\n" - data = await self.process_result(request, value) - - yield self.encode_json(data) + "\n" # type: ignore + # TODO: maybe abstract this out? maybe with a protocol + if isinstance(result, MapAsyncIterator): + stream = self._get_stream(request, result) return await self.create_multipart_response(stream, sub_response) @@ -228,6 +224,21 @@ async def stream(): response_data=response_data, sub_response=sub_response ) + def _get_stream( + self, request: Request, result: MapAsyncIterator, separator: str = "graphql" + ) -> Callable[[], AsyncGenerator[str, None]]: + async def stream(): + async for value in result: + yield f"\r\n--{separator}\r\n" + yield "Content-Type: application/json\r\n\r\n" + data = await self.process_result(request, value) + + yield self.encode_json(data) + "\n" # type: ignore + + yield f"\r\n--{separator}--\r\n" + + return stream + async def parse_multipart_subscriptions( self, request: AsyncHTTPRequestAdapter ) -> Dict[str, str]: From 033d0bfe7e77ae4fb583f1289acb66cdb996361e Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 13 Oct 2023 18:13:26 -0700 Subject: [PATCH 10/57] Improve naming --- strawberry/http/async_base_view.py | 14 ++++++++------ strawberry/types/__init__.py | 9 +++++++-- strawberry/types/execution.py | 12 +++++++++++- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 40f7153e0a..6954898265 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -11,7 +11,7 @@ Union, ) -from graphql import GraphQLError, MapAsyncIterator +from graphql import GraphQLError from strawberry import UNSET from strawberry.exceptions import MissingQueryError @@ -19,7 +19,7 @@ from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData, process_result from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError -from strawberry.types import ExecutionResult +from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.graphql import OperationType from .base import BaseView @@ -104,7 +104,7 @@ async def create_multipart_response( async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] - ) -> Union[ExecutionResult, MapAsyncIterator]: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: request_adapter = self.request_adapter_class(request) try: @@ -209,8 +209,7 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e - # TODO: maybe abstract this out? maybe with a protocol - if isinstance(result, MapAsyncIterator): + if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) return await self.create_multipart_response(stream, sub_response) @@ -225,7 +224,10 @@ async def run( ) def _get_stream( - self, request: Request, result: MapAsyncIterator, separator: str = "graphql" + self, + request: Request, + result: SubscriptionExecutionResult, + separator: str = "graphql", ) -> Callable[[], AsyncGenerator[str, None]]: async def stream(): async for value in result: diff --git a/strawberry/types/__init__.py b/strawberry/types/__init__.py index 0187701c58..72f735d229 100644 --- a/strawberry/types/__init__.py +++ b/strawberry/types/__init__.py @@ -1,4 +1,9 @@ -from .execution import ExecutionContext, ExecutionResult +from .execution import ExecutionContext, ExecutionResult, SubscriptionExecutionResult from .info import Info -__all__ = ["ExecutionContext", "ExecutionResult", "Info"] +__all__ = [ + "ExecutionContext", + "ExecutionResult", + "SubscriptionExecutionResult", + "Info", +] diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 08330ad93c..6d9a9bd56e 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -9,8 +9,9 @@ Optional, Tuple, Type, + runtime_checkable, ) -from typing_extensions import TypedDict +from typing_extensions import Protocol, TypedDict from graphql import specified_rules @@ -94,3 +95,12 @@ class ExecutionResult: class ParseOptions(TypedDict): max_tokens: NotRequired[int] + + +@runtime_checkable +class SubscriptionExecutionResult(Protocol): + def __aiter__(self) -> SubscriptionExecutionResult: + ... + + async def __anext__(self) -> Any: + ... From 33104b80306192448162451149765203c53ba00a Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 13 Oct 2023 20:21:31 -0700 Subject: [PATCH 11/57] Run subs in execute --- strawberry/http/async_base_view.py | 9 ------- strawberry/schema/execute.py | 39 +++++++++++++++++++----------- strawberry/schema/schema.py | 4 +-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 6954898265..b8ff3041c0 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -122,15 +122,6 @@ async def execute_operation( assert self.schema - # TODO: check if this is a subscription - return await self.schema.subscribe( - request_data.query, - root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - ) - return await self.schema.execute( request_data.query, root_value=root_value, diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 7a35560410..c6ac3d91ac 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -16,13 +16,14 @@ cast, ) -from graphql import GraphQLError, parse +from graphql import GraphQLError, parse, subscribe from graphql import execute as original_execute from graphql.validation import validate from strawberry.exceptions import MissingQueryError from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.types import ExecutionResult +from strawberry.types.graphql import OperationType from .exceptions import InvalidOperationTypeError @@ -36,9 +37,8 @@ from graphql.validation import ASTValidationRule from strawberry.extensions import SchemaExtension - from strawberry.types import ExecutionContext + from strawberry.types import ExecutionContext, SubscriptionExecutionResult from strawberry.types.execution import ParseOptions - from strawberry.types.graphql import OperationType def parse_document(query: str, **kwargs: Unpack[ParseOptions]) -> DocumentNode: @@ -77,7 +77,7 @@ async def execute( execution_context: ExecutionContext, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], -) -> ExecutionResult: +) -> Union[ExecutionResult, SubscriptionExecutionResult]: extensions_runner = SchemaExtensionsRunner( execution_context=execution_context, extensions=list(extensions), @@ -128,16 +128,27 @@ async def execute( async with extensions_runner.executing(): if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) + if execution_context.operation_type == OperationType.SUBSCRIPTION: + # TODO: should we process errors here? + return await subscribe( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + context_value=execution_context.context, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + ) + else: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, + ) if isawaitable(result): result = await cast(Awaitable["GraphQLExecutionResult"], result) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 8a4c538ee7..edd6f371b0 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -55,7 +55,7 @@ from strawberry.extensions import SchemaExtension from strawberry.field import StrawberryField from strawberry.type import StrawberryType - from strawberry.types import ExecutionResult + from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.union import StrawberryUnion DEFAULT_ALLOWED_OPERATION_TYPES = { @@ -239,7 +239,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> ExecutionResult: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES From b6ca9f9ce5105adde58443a3edbca8e255ad32db Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 13 Oct 2023 20:36:15 -0700 Subject: [PATCH 12/57] Workaround httpx for now --- strawberry/schema/execute.py | 3 +- tests/http/clients/base.py | 34 +++++++++++++++++------ tests/http/clients/fastapi.py | 4 +-- tests/http/test_multipart_subscription.py | 17 ++++++------ 4 files changed, 37 insertions(+), 21 deletions(-) diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index c6ac3d91ac..d438903e90 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -130,7 +130,8 @@ async def execute( if not execution_context.result: if execution_context.operation_type == OperationType.SUBSCRIPTION: # TODO: should we process errors here? - return await subscribe( + # TODO: make our own wrapper? + return await subscribe( # type: ignore - I don't like this ignore schema, execution_context.graphql_document, root_value=execution_context.root_value, diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 15b260f7a1..f2c1d19ec1 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -1,6 +1,7 @@ import abc import contextlib import json +import logging from dataclasses import dataclass from io import BytesIO from typing import ( @@ -20,6 +21,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult +logger = logging.getLogger("strawberry.test.http_client") + JSON = Dict[str, object] ResultOverrideFunction = Optional[Callable[[ExecutionResult], GraphQLHTTPResponse]] @@ -46,19 +49,34 @@ def json(self) -> JSON: return json.loads(self.data) async def streaming_json(self) -> AsyncIterable[JSON]: - assert isinstance(self.data, AsyncIterable) - if not self.is_multipart: raise ValueError("Streaming not supported") - # assuming we receive lines + def parse_chunk(text: str) -> Union[JSON, None]: + # TODO: better parsing? :) + with contextlib.suppress(json.JSONDecodeError): + return json.loads(text) - async for data in self.data: - text = data.decode("utf-8").strip() + if isinstance(self.data, AsyncIterable): + chunks = self.data - # TODO: this is silly, but bear with me :) - with contextlib.suppress(json.JSONDecodeError): - yield json.loads(text) + async for chunk in chunks: + text = chunk.decode("utf-8").strip() + + if data := parse_chunk(text): + yield data + else: + # TODO: we do this because httpx doesn't support streaming + # it would be nice to fix httpx instead of doing this, + # but we might have the same issue in other clients too + # TODO: better message + logger.warning("Didn't receive a stream, parsing it sync") + + chunks = self.data.decode("utf-8").split("\r\n") + + for chunk in chunks: + if data := parse_chunk(chunk): + yield data class HttpClient(abc.ABC): diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 91481eea0e..cc376f81d6 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -172,9 +172,7 @@ async def post( json: Optional[JSON] = None, headers: Optional[Dict[str, str]] = None, ) -> Response: - response = self.client.post( - url, headers=headers, content=data, json=json, stream=True - ) + response = self.client.post(url, headers=headers, content=data, json=json) return await self._handle_response(response) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 97d4acaa1b..26af4683f0 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,21 +1,20 @@ -import contextlib from typing import Type import pytest from .clients.base import HttpClient +# TODO: httpx doesn't support streaming, so we can only test the full output for now + @pytest.fixture() def http_client(http_client_class: Type[HttpClient]) -> HttpClient: - with contextlib.suppress(ImportError): - from .clients.fastapi import FastAPIHttpClient - - if http_client_class is FastAPIHttpClient: - # TODO: we could test this, but it doesn't make a lot of sense - # we should fix httpx instead :) - # https://github.com/encode/httpx/issues/2186 - pytest.xfail(reason="HTTPX doesn't support streaming yet") + # with contextlib.suppress(ImportError): + + # if http_client_class is FastAPIHttpClient: + # # TODO: we could test this, but it doesn't make a lot of sense + # # we should fix httpx instead :) + # # https://github.com/encode/httpx/issues/2186 return http_client_class() From 84c2413beb213e99c3fb1c30e20f0f0a7f9bf008 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 14:01:45 +0100 Subject: [PATCH 13/57] Aiohttp support --- strawberry/aiohttp/views.py | 36 ++++++++++++++++++++++- tests/http/test_multipart_subscription.py | 17 ++++++----- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 7ff81b4c82..c9c0b4b4fe 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -6,10 +6,13 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, + Callable, Dict, Iterable, Mapping, Optional, + Union, cast, ) @@ -74,7 +77,13 @@ def content_type(self) -> Optional[str]: class GraphQLView( - AsyncBaseHTTPView[web.Request, web.Response, web.Response, Context, RootValue] + AsyncBaseHTTPView[ + web.Request, + Union[web.Response, web.StreamResponse], + web.Response, + Context, + RootValue, + ] ): # Mark the view as coroutine so that AIOHTTP does not confuse it with a deprecated # bare handler function. @@ -123,6 +132,8 @@ async def __call__(self, request: web.Request) -> web.StreamResponse: if not ws_test.ok: try: + # TODO: pass this down from run to multipart thingy + self.request = request return await self.run(request=request) except HTTPException as e: return web.Response( @@ -169,3 +180,26 @@ def create_response( sub_response.content_type = "application/json" return sub_response + + async def create_multipart_response( + self, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: web.Response, + ) -> web.StreamResponse: + # TODO: use sub response + response = web.StreamResponse( + status=200, + headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + reason="OK", + ) + + await response.prepare(self.request) + + async for data in stream(): + await response.write(data.encode()) + + await response.write_eof() + return response diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 26af4683f0..9e6dd45117 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,20 +1,23 @@ +import contextlib from typing import Type import pytest from .clients.base import HttpClient -# TODO: httpx doesn't support streaming, so we can only test the full output for now - @pytest.fixture() def http_client(http_client_class: Type[HttpClient]) -> HttpClient: - # with contextlib.suppress(ImportError): + with contextlib.suppress(ImportError): + from .clients.channels import SyncChannelsHttpClient + from .clients.django import DjangoHttpClient + + if http_client_class is DjangoHttpClient: + pytest.xfail(reason="(sync) DjangoHttpClient doesn't support subscriptions") - # if http_client_class is FastAPIHttpClient: - # # TODO: we could test this, but it doesn't make a lot of sense - # # we should fix httpx instead :) - # # https://github.com/encode/httpx/issues/2186 + # TODO: why do we have a sync channels client? + if http_client_class is SyncChannelsHttpClient: + pytest.xfail(reason="SyncChannelsHttpClient doesn't support subscriptions") return http_client_class() From b3fced51bdfc20ffce26501cab44c56334556aa5 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 14:04:55 +0100 Subject: [PATCH 14/57] ASGI --- strawberry/asgi/__init__.py | 20 +++++++++++++++++++- strawberry/django/views.py | 5 +++-- strawberry/fastapi/router.py | 5 +++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index dc2998b259..4f9941b77e 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -4,6 +4,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Callable, Mapping, Optional, Sequence, @@ -13,7 +15,12 @@ from starlette import status from starlette.requests import Request -from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.responses import ( + HTMLResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.websockets import WebSocket from strawberry.asgi.handlers import ( @@ -206,3 +213,14 @@ def create_response( response.status_code = sub_response.status_code return response + + async def create_multipart_response( + self, stream: Callable[[], AsyncIterator[str]], sub_response: Response + ) -> Response: + return StreamingResponse( + stream(), + headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index e8f0377be8..d738b27565 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Mapping, Optional, @@ -183,10 +184,10 @@ def create_response( return response async def create_multipart_response( - self, response_stream: ..., sub_response: HttpResponse + self, stream: Callable[[], AsyncIterator[Any]], sub_response: HttpResponse ) -> HttpResponseBase: return StreamingHttpResponse( - streaming_content=response_stream(), + streaming_content=stream(), headers={ "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 6418ca6dc6..d36ccf4b7a 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -5,6 +5,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Awaitable, Callable, Mapping, @@ -314,10 +315,10 @@ def create_response( return response async def create_multipart_response( - self, response_stream: ..., sub_response: Response + self, stream: Callable[[], AsyncIterator[str]], sub_response: Response ) -> Response: return StreamingResponse( - response_stream(), + stream(), headers={ "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", From 9821e242b3b9348fb470877caf39cca5945473be Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 14:21:11 +0100 Subject: [PATCH 15/57] Flask attempt --- strawberry/flask/views.py | 21 ++++++++++++++++++++- tests/http/test_multipart_subscription.py | 14 +++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index 315bc35e2c..a8f2012438 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + List, + Mapping, + Optional, + Union, + cast, +) from flask import Request, Response, render_template_string, request from flask.views import View @@ -161,3 +171,12 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore response=e.reason, status=e.status_code, ) + + async def create_multipart_response( + self, stream: Callable[[], AsyncIterator[str]], sub_response: Response + ) -> Response: + # TODO: this is not supported :) + return stream(), { + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + } diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 9e6dd45117..57e0896971 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -9,16 +9,28 @@ @pytest.fixture() def http_client(http_client_class: Type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): - from .clients.channels import SyncChannelsHttpClient from .clients.django import DjangoHttpClient if http_client_class is DjangoHttpClient: pytest.xfail(reason="(sync) DjangoHttpClient doesn't support subscriptions") + with contextlib.suppress(ImportError): + from .clients.channels import SyncChannelsHttpClient + # TODO: why do we have a sync channels client? if http_client_class is SyncChannelsHttpClient: pytest.xfail(reason="SyncChannelsHttpClient doesn't support subscriptions") + with contextlib.suppress(ImportError): + from .clients.async_flask import AsyncFlaskHttpClient + from .clients.flask import FlaskHttpClient + + if http_client_class is FlaskHttpClient: + pytest.xfail(reason="FlaskHttpClient doesn't support subscriptions") + + if http_client_class is AsyncFlaskHttpClient: + pytest.xfail(reason="AsyncFlaskHttpClient doesn't support subscriptions") + return http_client_class() From 0a3557c3f90e6709534e64af56683d9ae5a9f2e3 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 16:27:08 +0100 Subject: [PATCH 16/57] Support for sanic --- strawberry/sanic/views.py | 24 +++++++++++++++++++++++ tests/http/test_multipart_subscription.py | 9 ++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 06590124a5..d3783489e3 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -5,6 +5,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, + Callable, Dict, List, Mapping, @@ -160,13 +162,35 @@ def create_response( ) async def post(self, request: Request) -> HTTPResponse: + self.request = request + try: return await self.run(request) except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) async def get(self, request: Request) -> HTTPResponse: + self.request = request try: return await self.run(request) except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) + + async def create_multipart_response( + self, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, + ) -> HTTPResponse: + response = await self.request.respond( + content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + headers={ + "Transfer-Encoding": "chunked", + }, + ) + + async for chunk in stream(): + await response.send(chunk) + + await response.eof() + + return response diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 57e0896971..b2ad539671 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -31,6 +31,12 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: if http_client_class is AsyncFlaskHttpClient: pytest.xfail(reason="AsyncFlaskHttpClient doesn't support subscriptions") + with contextlib.suppress(ImportError): + from .clients.chalice import ChaliceHttpClient + + if http_client_class is ChaliceHttpClient: + pytest.xfail(reason="ChaliceHttpClient doesn't support subscriptions") + return http_client_class() @@ -42,7 +48,8 @@ async def test_graphql_query(http_client: HttpClient): "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', }, headers={ - # TODO: this header might just be for django (the way it is written) + # TODO: fix headers :) second one is for django + "content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", "CONTENT_TYPE": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, ) From c807e96e0b5f683c13acf4fc374ad155ceafd6b6 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 17:19:22 +0100 Subject: [PATCH 17/57] Wip channels --- strawberry/channels/handlers/http_handler.py | 61 +++++++++++++++++--- strawberry/http/async_base_view.py | 3 +- tests/http/test_multipart_subscription.py | 18 ++++-- 3 files changed, 66 insertions(+), 16 deletions(-) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 72c2406418..ce49bf1cbd 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -8,7 +8,16 @@ import json from functools import cached_property from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Mapping, + Optional, + Union, +) from urllib.parse import parse_qs from django.conf import settings @@ -42,6 +51,14 @@ class ChannelsResponse: headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) +@dataclasses.dataclass +class MultipartChannelsResponse: + stream: Callable[[], AsyncGenerator[str, None]] + status: int = 200 + content_type: str = "multipart/mixed;boundary=graphql;subscriptionSpec=1.0" + headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) + + @dataclasses.dataclass class ChannelsRequest: consumer: ChannelsConsumer @@ -170,19 +187,37 @@ def create_response( headers={k.encode(): v.encode() for k, v in sub_response.headers.items()}, ) + async def run( + self, + request: Any, + context: Optional[Any] = UNSET, + root_value: Optional[Any] = UNSET, + ) -> Union[ChannelsResponse, MultipartChannelsResponse]: + # putting this here just for type checking + raise NotImplementedError() + async def handle(self, body: bytes) -> None: request = ChannelsRequest(consumer=self, body=body) try: - response: ChannelsResponse = await self.run(request) + response = await self.run(request) if b"Content-Type" not in response.headers: response.headers[b"Content-Type"] = response.content_type.encode() - await self.send_response( - response.status, - response.content, - headers=response.headers, - ) + if isinstance(response, MultipartChannelsResponse): + response.headers[b"Transfer-Encoding"] = b"chunked" + await self.send_headers(headers=response.headers) + + async for chunk in response.stream(): + # TODO: we should change more body + await self.send_body(chunk.encode("utf-8"), more_body=True) + + else: + await self.send_response( + response.status, + response.content, + headers=response.headers, + ) except HTTPException as e: await self.send_response(e.status_code, e.reason.encode()) @@ -191,7 +226,7 @@ class GraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, AsyncBaseHTTPView[ ChannelsRequest, - ChannelsResponse, + Union[ChannelsResponse, MultipartChannelsResponse], TemporalResponse, Context, RootValue, @@ -235,6 +270,14 @@ async def get_context( async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: return TemporalResponse() + async def create_multipart_response( + self, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, + ) -> MultipartChannelsResponse: + # TODO: sub response + return MultipartChannelsResponse(stream=stream) + class SyncGraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, @@ -279,5 +322,5 @@ def run( request: ChannelsRequest, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> ChannelsResponse: + ) -> ChannelsResponse: # type: ignore return super().run(request, context, root_value) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index b8ff3041c0..b492affeeb 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -96,11 +96,10 @@ def create_response( ) -> Response: ... - @abc.abstractmethod async def create_multipart_response( self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: SubResponse ) -> Response: - ... + raise ValueError("Multipart responses are not supported1") async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index b2ad539671..6d5e61abf4 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -9,24 +9,32 @@ @pytest.fixture() def http_client(http_client_class: Type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): + import django + + if django.VERSION < (4, 2): + pytest.skip(reason="Django < 4.2 doesn't async streaming responses") + from .clients.django import DjangoHttpClient if http_client_class is DjangoHttpClient: - pytest.xfail(reason="(sync) DjangoHttpClient doesn't support subscriptions") + pytest.skip(reason="(sync) DjangoHttpClient doesn't support subscriptions") with contextlib.suppress(ImportError): - from .clients.channels import SyncChannelsHttpClient + from .clients.channels import ChannelsHttpClient, SyncChannelsHttpClient # TODO: why do we have a sync channels client? if http_client_class is SyncChannelsHttpClient: - pytest.xfail(reason="SyncChannelsHttpClient doesn't support subscriptions") + pytest.skip(reason="SyncChannelsHttpClient doesn't support subscriptions") + + if http_client_class is ChannelsHttpClient: + pytest.xfail(reason="ChannelsHttpClient is broken at the moment") with contextlib.suppress(ImportError): from .clients.async_flask import AsyncFlaskHttpClient from .clients.flask import FlaskHttpClient if http_client_class is FlaskHttpClient: - pytest.xfail(reason="FlaskHttpClient doesn't support subscriptions") + pytest.skip(reason="FlaskHttpClient doesn't support subscriptions") if http_client_class is AsyncFlaskHttpClient: pytest.xfail(reason="AsyncFlaskHttpClient doesn't support subscriptions") @@ -35,7 +43,7 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: from .clients.chalice import ChaliceHttpClient if http_client_class is ChaliceHttpClient: - pytest.xfail(reason="ChaliceHttpClient doesn't support subscriptions") + pytest.skip(reason="ChaliceHttpClient doesn't support subscriptions") return http_client_class() From 8f6934e75bf2c485d149946237c9e11e329d271c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 17:38:29 +0100 Subject: [PATCH 18/57] Fix syntax --- strawberry/schema/execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index d438903e90..c13f974322 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -131,7 +131,7 @@ async def execute( if execution_context.operation_type == OperationType.SUBSCRIPTION: # TODO: should we process errors here? # TODO: make our own wrapper? - return await subscribe( # type: ignore - I don't like this ignore + return await subscribe( # type: ignore schema, execution_context.graphql_document, root_value=execution_context.root_value, From 6f932aa70a08b2b79023151f12dbd69c3b3fb120 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 14 Oct 2023 17:40:34 +0100 Subject: [PATCH 19/57] Fix various type issues --- strawberry/channels/handlers/http_handler.py | 4 ++-- strawberry/flask/views.py | 11 ----------- strawberry/http/async_base_view.py | 2 +- strawberry/schema/base.py | 8 ++++++-- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index ce49bf1cbd..88c95e782d 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -322,5 +322,5 @@ def run( request: ChannelsRequest, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> ChannelsResponse: # type: ignore - return super().run(request, context, root_value) + ) -> ChannelsResponse: # pyright: ignore + return super().run(request, context, root_value) # type: ignore diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index a8f2012438..b1b7fe9068 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -3,8 +3,6 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Callable, List, Mapping, Optional, @@ -171,12 +169,3 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore response=e.reason, status=e.status_code, ) - - async def create_multipart_response( - self, stream: Callable[[], AsyncIterator[str]], sub_response: Response - ) -> Response: - # TODO: this is not supported :) - return stream(), { - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", - } diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index b492affeeb..028269b6e8 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -225,7 +225,7 @@ async def stream(): yield "Content-Type: application/json\r\n\r\n" data = await self.process_result(request, value) - yield self.encode_json(data) + "\n" # type: ignore + yield self.encode_json(data) + "\n" yield f"\r\n--{separator}--\r\n" diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index a1c286c6d0..c1b2a28ecc 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -14,7 +14,11 @@ from strawberry.directive import StrawberryDirective from strawberry.enum import EnumDefinition from strawberry.schema.schema_converter import GraphQLCoreConverter - from strawberry.types import ExecutionContext, ExecutionResult + from strawberry.types import ( + ExecutionContext, + ExecutionResult, + SubscriptionExecutionResult, + ) from strawberry.types.graphql import OperationType from strawberry.types.types import StrawberryObjectDefinition from strawberry.union import StrawberryUnion @@ -39,7 +43,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> ExecutionResult: + ) -> Union[ExecutionResult, SubscriptionExecutionResult]: raise NotImplementedError @abstractmethod From 5e4e1449106f3189fafdd6bdbc1777fcd88ba52b Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sun, 15 Oct 2023 10:05:20 +0100 Subject: [PATCH 20/57] Pragma no cover --- strawberry/types/execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 6d9a9bd56e..e3d6518e55 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -99,8 +99,8 @@ class ParseOptions(TypedDict): @runtime_checkable class SubscriptionExecutionResult(Protocol): - def __aiter__(self) -> SubscriptionExecutionResult: + def __aiter__(self) -> SubscriptionExecutionResult: # pragma: no cover ... - async def __anext__(self) -> Any: + async def __anext__(self) -> Any: # pragma: no cover ... From 66f327d40b35f981646d02031bc0269ac7b4a876 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sun, 15 Oct 2023 10:05:33 +0100 Subject: [PATCH 21/57] Initial feature table --- docs/integrations/index.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 docs/integrations/index.md diff --git a/docs/integrations/index.md b/docs/integrations/index.md new file mode 100644 index 0000000000..09bedc7df9 --- /dev/null +++ b/docs/integrations/index.md @@ -0,0 +1,12 @@ +# integrations + +WIP: + +| name | Supports sync | Supports async | Supports subscriptions via websockets | Supports subscriptions via multipart HTTP | Supports file uploads | Supports batch queries | +| -------------------------------------------- | ------------- | -------------------- | ------------------------------------- | ----------------------------------------- | --------------------- | ---------------------- | +| [django](/docs/integrations/django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ | ✅ | ❌ | +| [starlette](/docs/integrations/starlette.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [aiohttp](/docs/integrations/aiohttp.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [flask](/docs/integrations/flask.md) | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | +| [channels](/docs/integrations/channels.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | +| [fastapi](/docs/integrations/fastapi.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | From fc7ecd8dfc7ca9a41da52d1183a5827b38d151fe Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sun, 15 Oct 2023 10:07:43 +0100 Subject: [PATCH 22/57] Relative urls --- docs/integrations/index.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/integrations/index.md b/docs/integrations/index.md index 09bedc7df9..bc4b2c377d 100644 --- a/docs/integrations/index.md +++ b/docs/integrations/index.md @@ -2,11 +2,11 @@ WIP: -| name | Supports sync | Supports async | Supports subscriptions via websockets | Supports subscriptions via multipart HTTP | Supports file uploads | Supports batch queries | -| -------------------------------------------- | ------------- | -------------------- | ------------------------------------- | ----------------------------------------- | --------------------- | ---------------------- | -| [django](/docs/integrations/django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ | ✅ | ❌ | -| [starlette](/docs/integrations/starlette.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [aiohttp](/docs/integrations/aiohttp.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [flask](/docs/integrations/flask.md) | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | -| [channels](/docs/integrations/channels.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | -| [fastapi](/docs/integrations/fastapi.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| name | Supports sync | Supports async | Supports subscriptions via websockets | Supports subscriptions via multipart HTTP | Supports file uploads | Supports batch queries | +| --------------------------- | ------------- | -------------------- | ------------------------------------- | ----------------------------------------- | --------------------- | ---------------------- | +| [django](//django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ | ✅ | ❌ | +| [starlette](//starlette.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [aiohttp](//aiohttp.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [flask](//flask.md) | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | +| [channels](//channels.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | +| [fastapi](//fastapi.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | From 0b73f58acfaca18ced65ea9a87a13704d247f416 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sun, 15 Oct 2023 21:08:00 +0100 Subject: [PATCH 23/57] Fix channels issue --- strawberry/channels/handlers/http_handler.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 88c95e782d..ae592f42e8 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -187,15 +187,6 @@ def create_response( headers={k.encode(): v.encode() for k, v in sub_response.headers.items()}, ) - async def run( - self, - request: Any, - context: Optional[Any] = UNSET, - root_value: Optional[Any] = UNSET, - ) -> Union[ChannelsResponse, MultipartChannelsResponse]: - # putting this here just for type checking - raise NotImplementedError() - async def handle(self, body: bytes) -> None: request = ChannelsRequest(consumer=self, body=body) try: From bec5a7ad99bdcd9a518f2ae6302483057fabf51a Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 20 Oct 2023 13:51:29 +0400 Subject: [PATCH 24/57] Handle heartbeat --- strawberry/http/async_base_view.py | 72 +++++++++++++++++++++-- tests/http/clients/base.py | 7 ++- tests/http/test_multipart_subscription.py | 2 +- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 028269b6e8..45a3860e2f 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -1,6 +1,8 @@ import abc +import asyncio import json from typing import ( + Any, AsyncGenerator, Callable, Dict, @@ -8,6 +10,7 @@ List, Mapping, Optional, + Tuple, Union, ) @@ -213,6 +216,68 @@ async def run( response_data=response_data, sub_response=sub_response ) + def encode_multipart_data(self, data: Any, separator: str) -> str: + return "".join( + [ + f"\r\n--{separator}\r\n", + "Content-Type: application/json\r\n\r\n", + self.encode_json(data), + "\n", + ] + ) + + def _stream_with_heartbeat( + self, stream: Callable[[], AsyncGenerator[str, None]] + ) -> Callable[[], AsyncGenerator[str, None]]: + """Adds a heartbeat to the stream, to prevent the connection from closing + when there are no messages being sent.""" + # TODO: handle errors + # TODO: should we do this more efficiently? and only send the heartbeat when + # 5 seconds have passed after the last message? (apollo router doesn't seem to do this) + queue = asyncio.Queue[Tuple[bool, Any]](1) + + cancelling = False + + async def drain(): + try: + async for item in stream(): + await queue.put((False, item)) + except Exception as e: + if not cancelling: + await queue.put((True, e)) + else: + raise + + async def heartbeat(): + while True: + await queue.put((False, self.encode_multipart_data({}, "graphql"))) + + await asyncio.sleep(5) + + async def merged() -> AsyncGenerator[str, None]: + heartbeat_task = asyncio.create_task(heartbeat()) + task = asyncio.create_task(drain()) + + def cancel_tasks(): + nonlocal cancelling + cancelling = True + task.cancel() + heartbeat_task.cancel() + + try: + while not task.done(): + raised, data = await queue.get() + + if raised: + cancel_tasks() + raise data + + yield data + finally: + cancel_tasks() + + return merged + def _get_stream( self, request: Request, @@ -221,15 +286,12 @@ def _get_stream( ) -> Callable[[], AsyncGenerator[str, None]]: async def stream(): async for value in result: - yield f"\r\n--{separator}\r\n" - yield "Content-Type: application/json\r\n\r\n" data = await self.process_result(request, value) - - yield self.encode_json(data) + "\n" + yield self.encode_multipart_data(data, separator) yield f"\r\n--{separator}--\r\n" - return stream + return self._stream_with_heartbeat(stream) async def parse_multipart_subscriptions( self, request: AsyncHTTPRequestAdapter diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index f2c1d19ec1..7a594f97b7 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -61,10 +61,11 @@ def parse_chunk(text: str) -> Union[JSON, None]: chunks = self.data async for chunk in chunks: - text = chunk.decode("utf-8").strip() + lines = chunk.decode("utf-8").split("\r\n") - if data := parse_chunk(text): - yield data + for text in lines: + if data := parse_chunk(text): + yield data else: # TODO: we do this because httpx doesn't support streaming # it would be nice to fix httpx instead of doing this, diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 6d5e61abf4..6e8e95498a 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -49,7 +49,7 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: # TODO: do multipart subscriptions work on both GET and POST? -async def test_graphql_query(http_client: HttpClient): +async def test_multipart_subscription(http_client: HttpClient): response = await http_client.post( url="/graphql", json={ From b73dd1eff005bf91f069e27443328a939481a8d5 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 20 Oct 2023 14:00:33 +0400 Subject: [PATCH 25/57] Remove type ignore --- strawberry/channels/handlers/http_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index ae592f42e8..f3bbd73734 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -314,4 +314,4 @@ def run( context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, ) -> ChannelsResponse: # pyright: ignore - return super().run(request, context, root_value) # type: ignore + return super().run(request, context, root_value) From 52b2f2d467660baa99cc0b5f55ad1a83f740b8b9 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 20 Oct 2023 14:10:04 +0400 Subject: [PATCH 26/57] Add blank release file --- RELEASE.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..de7cc70ac5 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +TODO :) From 28321b8d4cb02a4e10f608abce38d206dd533c1f Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 20 Oct 2023 18:21:27 +0400 Subject: [PATCH 27/57] Wrap response in payload --- strawberry/http/async_base_view.py | 4 ++-- tests/http/test_multipart_subscription.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 45a3860e2f..8f8bbbaa9d 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -286,8 +286,8 @@ def _get_stream( ) -> Callable[[], AsyncGenerator[str, None]]: async def stream(): async for value in result: - data = await self.process_result(request, value) - yield self.encode_multipart_data(data, separator) + response = await self.process_result(request, value) + yield self.encode_multipart_data({"payload": response}, separator) yield f"\r\n--{separator}--\r\n" diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 6e8e95498a..f4219b5f0d 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -64,4 +64,4 @@ async def test_multipart_subscription(http_client: HttpClient): data = [d async for d in response.streaming_json()] - assert data == [{"data": {"echo": "Hello world"}}] + assert data == [{"payload": {"data": {"echo": "Hello world"}}}] From 36b121ea69097224f5c477a257ab28be39abb99e Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 24 Oct 2023 18:34:50 +0100 Subject: [PATCH 28/57] Update integrations --- docs/integrations/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/integrations/index.md b/docs/integrations/index.md index bc4b2c377d..698a909c5b 100644 --- a/docs/integrations/index.md +++ b/docs/integrations/index.md @@ -4,7 +4,7 @@ WIP: | name | Supports sync | Supports async | Supports subscriptions via websockets | Supports subscriptions via multipart HTTP | Supports file uploads | Supports batch queries | | --------------------------- | ------------- | -------------------- | ------------------------------------- | ----------------------------------------- | --------------------- | ---------------------- | -| [django](//django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ | ✅ | ❌ | +| [django](//django.md) | ✅ | ✅ (with Async view) | ❌ (use Channels for websockets) | ✅ (From Django 4.2) | ✅ | ❌ | | [starlette](//starlette.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [aiohttp](//aiohttp.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [flask](//flask.md) | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | From 5a4b4c4e1f91f87dd341109698e835506273bb40 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 8 Nov 2023 14:30:57 +0100 Subject: [PATCH 29/57] Improve how we check for content types --- strawberry/aiohttp/views.py | 2 +- strawberry/django/views.py | 2 +- strawberry/http/async_base_view.py | 16 +++++--- strawberry/http/base.py | 13 ++++++ strawberry/http/parse_content_type.py | 17 ++++++++ strawberry/http/sync_base_view.py | 17 ++++++-- strawberry/quart/views.py | 10 ++++- tests/http/test_multipart_subscription.py | 2 - tests/http/test_parse_content_type.py | 49 +++++++++++++++++++++++ 9 files changed, 113 insertions(+), 15 deletions(-) create mode 100644 strawberry/http/parse_content_type.py create mode 100644 tests/http/test_parse_content_type.py diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 5e06181715..a01ee2cbda 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -74,7 +74,7 @@ async def get_form_data(self) -> FormData: @property def content_type(self) -> Optional[str]: - return self.request.content_type + return self.headers.get("content-type") class GraphQLView( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 19c0955cf4..e9e67565df 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -122,7 +122,7 @@ def headers(self) -> Mapping[str, str]: @property def content_type(self) -> Optional[str]: - return self.request.content_type + return self.headers.get("Content-type") async def get_body(self) -> str: return self.request.body.decode() diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 3ac43e7767..94288d252f 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -19,7 +19,11 @@ from strawberry import UNSET from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files -from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData, process_result +from strawberry.http import ( + GraphQLHTTPResponse, + GraphQLRequestData, + process_result, +) from strawberry.http.ides import GraphQL_IDE from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError @@ -28,6 +32,7 @@ from .base import BaseView from .exceptions import HTTPException +from .parse_content_type import parse_content_type from .types import FormData, HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse @@ -300,14 +305,13 @@ async def parse_multipart_subscriptions( async def parse_http_body( self, request: AsyncHTTPRequestAdapter ) -> GraphQLRequestData: - content_type = request.content_type or "" + content_type, params = parse_content_type(request.content_type or "") - if "application/json" in content_type: + if content_type == "application/json": data = self.parse_json(await request.get_body()) - elif content_type.startswith("multipart/form-data"): + elif content_type == "multipart/form-data": data = await self.parse_multipart(request) - elif content_type.startswith("multipart/mixed"): - # TODO: do a check that checks if this is a multipart subscription + elif self._is_multipart_subscriptions(content_type, params): data = await self.parse_multipart_subscriptions(request) elif request.method == "GET": data = self.parse_query_params(request.query_params) diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 255027c050..e7308725b7 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -76,3 +76,16 @@ def graphql_ide_html(self) -> str: replace_variables=self._ide_replace_variables, graphql_ide=self.graphql_ide, ) + + def _is_multipart_subscriptions( + self, content_type: str, params: Dict[str, str] + ) -> bool: + if content_type != "multipart/mixed": + return False + + if params.get("boundary") != "graphql": + return False + + return tuple( + part.strip() for part in params.get("subscriptionspec", "").split(",") + ) == ("1.0", "application/json") diff --git a/strawberry/http/parse_content_type.py b/strawberry/http/parse_content_type.py new file mode 100644 index 0000000000..424d434964 --- /dev/null +++ b/strawberry/http/parse_content_type.py @@ -0,0 +1,17 @@ +from email.message import Message +from typing import Dict, Tuple + + +def parse_content_type(content_type: str) -> Tuple[str, Dict[str, str]]: + """Parse a content type header into a mime-type and a dictionary of parameters.""" + + email = Message() + email["content-type"] = content_type + + params = email.get_params() + + assert params + + mime_type, _ = params.pop(0) + + return mime_type, dict(params) diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 34db2c0ed4..36fb5480d4 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -16,7 +16,11 @@ from strawberry import UNSET from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files -from strawberry.http import GraphQLHTTPResponse, GraphQLRequestData, process_result +from strawberry.http import ( + GraphQLHTTPResponse, + GraphQLRequestData, + process_result, +) from strawberry.http.ides import GraphQL_IDE from strawberry.schema import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError @@ -25,6 +29,7 @@ from .base import BaseView from .exceptions import HTTPException +from .parse_content_type import parse_content_type from .types import HTTPMethod, QueryParams from .typevars import Context, Request, Response, RootValue, SubResponse @@ -144,14 +149,18 @@ def parse_multipart(self, request: SyncHTTPRequestAdapter) -> Dict[str, str]: raise HTTPException(400, "File(s) missing in form data") from e def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData: - content_type = request.content_type or "" + content_type, params = parse_content_type(request.content_type or "") - if "application/json" in content_type: + if content_type == "application/json": data = self.parse_json(request.body) - elif content_type.startswith("multipart/form-data"): + elif content_type == "multipart/form-data": data = self.parse_multipart(request) elif request.method == "GET": data = self.parse_query_params(request.query_params) + elif self._is_multipart_subscriptions(content_type, params): + raise HTTPException( + 400, "Multipart subcriptions are not supported in sync mode" + ) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index e817194a9f..1e1ade4381 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Optional, cast from quart import Request, Response, request from quart.views import View @@ -102,3 +102,11 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore response=e.reason, status=e.status_code, ) + + async def create_multipart_response( + self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: Response + ) -> Response: + return stream(), 200, { # type: ignore + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + } diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index f4219b5f0d..5f7aa1906e 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -56,9 +56,7 @@ async def test_multipart_subscription(http_client: HttpClient): "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', }, headers={ - # TODO: fix headers :) second one is for django "content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", - "CONTENT_TYPE": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, ) diff --git a/tests/http/test_parse_content_type.py b/tests/http/test_parse_content_type.py new file mode 100644 index 0000000000..4ae0017e40 --- /dev/null +++ b/tests/http/test_parse_content_type.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple + +import pytest + +from strawberry.http.parse_content_type import parse_content_type + + +@pytest.mark.parametrize( + ("content_type", "expected"), + [ # type: ignore + ("application/json", ("application/json", {})), + ("", ("", {})), + ("application/json; charset=utf-8", ("application/json", {"charset": "utf-8"})), + ( + "application/json; charset=utf-8; boundary=foobar", + ("application/json", {"charset": "utf-8", "boundary": "foobar"}), + ), + ( + "application/json; boundary=foobar; charset=utf-8", + ("application/json", {"boundary": "foobar", "charset": "utf-8"}), + ), + ( + "application/json; boundary=foobar", + ("application/json", {"boundary": "foobar"}), + ), + ( + "application/json; boundary=foobar; charset=utf-8; foo=bar", + ( + "application/json", + {"boundary": "foobar", "charset": "utf-8", "foo": "bar"}, + ), + ), + ( + 'multipart/mixed; boundary="graphql"; subscriptionSpec=1.0, application/json', + ( + "multipart/mixed", + { + "boundary": "graphql", + "subscriptionspec": "1.0, application/json", + }, + ), + ), + ], +) +async def test_parse_content_type( + content_type: str, + expected: Tuple[str, Dict[str, str]], +): + assert parse_content_type(content_type) == expected From 2df7ce57f4deecbe86544f9ad840a569aa464b3d Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 10 Nov 2023 09:26:46 +0100 Subject: [PATCH 30/57] Run channels tests, even if they are broken --- tests/http/test_multipart_subscription.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 5f7aa1906e..39431d3721 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -20,15 +20,12 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: pytest.skip(reason="(sync) DjangoHttpClient doesn't support subscriptions") with contextlib.suppress(ImportError): - from .clients.channels import ChannelsHttpClient, SyncChannelsHttpClient + from .clients.channels import SyncChannelsHttpClient # TODO: why do we have a sync channels client? if http_client_class is SyncChannelsHttpClient: pytest.skip(reason="SyncChannelsHttpClient doesn't support subscriptions") - if http_client_class is ChannelsHttpClient: - pytest.xfail(reason="ChannelsHttpClient is broken at the moment") - with contextlib.suppress(ImportError): from .clients.async_flask import AsyncFlaskHttpClient from .clients.flask import FlaskHttpClient From 2c54b09702793fef8635b1d6953123027e02aefc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:20:28 +0000 Subject: [PATCH 31/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/quart/views.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 1e1ade4381..1466e55e92 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -106,7 +106,11 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore async def create_multipart_response( self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: Response ) -> Response: - return stream(), 200, { # type: ignore - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", - } + return ( + stream(), + 200, + { # type: ignore + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) From 0ca08bbeea869ecf7af474aea11c43470d28c8a6 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 15:26:32 +0100 Subject: [PATCH 32/57] Fix lint --- strawberry/django/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 3b42a0e09b..91a1901248 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -167,7 +167,7 @@ def __init__( def create_response( self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse ) -> HttpResponseBase: - data = self.encode_json(response_data) # type: ignore + data = self.encode_json(response_data) response = HttpResponse( data, From 2b6b6034b39458172754a85c667ca346f72ce068 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 16:01:59 +0100 Subject: [PATCH 33/57] Fix import --- strawberry/django/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 91a1901248..8647424093 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -18,11 +18,11 @@ from django.http import ( HttpRequest, HttpResponse, - HttpResponseBase, HttpResponseNotAllowed, JsonResponse, StreamingHttpResponse, ) +from django.http.response import HttpResponseBase from django.template import RequestContext, Template from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string From adb46604031b664a9593b0b9c89fbd45df310adf Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 16:15:59 +0100 Subject: [PATCH 34/57] await cancelled tasks --- strawberry/http/async_base_view.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 2a2c436c8a..805c9c0042 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -1,5 +1,6 @@ import abc import asyncio +import contextlib import json from typing import ( Any, @@ -262,23 +263,30 @@ async def merged() -> AsyncGenerator[str, None]: heartbeat_task = asyncio.create_task(heartbeat()) task = asyncio.create_task(drain()) - def cancel_tasks(): + async def cancel_tasks(): nonlocal cancelling cancelling = True task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await task + heartbeat_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await heartbeat_task + try: while not task.done(): raised, data = await queue.get() if raised: - cancel_tasks() + await cancel_tasks() raise data yield data finally: - cancel_tasks() + await cancel_tasks() return merged From 48e52e93fff32060fe0be650244907fd3f6ff6cb Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 16:19:34 +0100 Subject: [PATCH 35/57] Some refactoring --- tests/http/clients/async_django.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 61b72916d4..0e8bfca0ed 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -48,23 +48,21 @@ async def _do_request(self, request: RequestFactory) -> Response: try: response = await view(request) except Http404: - return Response( - status_code=404, data=b"Not found", headers=response.headers - ) + return Response(status_code=404, data=b"Not found", headers={}) except (BadRequest, SuspiciousOperation) as e: return Response( - status_code=400, data=e.args[0].encode(), headers=response.headers + status_code=400, + data=e.args[0].encode(), + headers={}, ) + data = ( + response.streaming_content + if isinstance(response, StreamingHttpResponse) + else response.content + ) - if isinstance(response, StreamingHttpResponse): - return Response( - status_code=response.status_code, - data=response.streaming_content, - headers=response.headers, - ) - else: - return Response( - status_code=response.status_code, - data=response.content, - headers=response.headers, - ) + return Response( + status_code=response.status_code, + data=data, + headers=response.headers, + ) From f56149b33b7b57f595cd0d96821f8a6cbe18c0f8 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 17:36:20 +0100 Subject: [PATCH 36/57] Remove stale comment --- strawberry/http/async_base_view.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 805c9c0042..d41f5c2d44 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -236,9 +236,6 @@ def _stream_with_heartbeat( ) -> Callable[[], AsyncGenerator[str, None]]: """Adds a heartbeat to the stream, to prevent the connection from closing when there are no messages being sent.""" - # TODO: handle errors - # TODO: should we do this more efficiently? and only send the heartbeat when - # 5 seconds have passed after the last message? (apollo router doesn't seem to do this) queue = asyncio.Queue[Tuple[bool, Any]](1) cancelling = False From f9d313781ee91424085b64ac0244b61e7dabc47d Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 17:38:56 +0100 Subject: [PATCH 37/57] update type --- strawberry/channels/handlers/http_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index f73adb8c64..5afcfbe5fe 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -330,10 +330,10 @@ def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: # handlers in a threadpool. Check SyncConsumer's documentation for more info: # https://github.com/django/channels/blob/main/channels/consumer.py#L104 @database_sync_to_async # pyright: ignore[reportIncompatibleMethodOverride] - def run( + def run( # type: ignore[override] self, request: ChannelsRequest, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> ChannelsResponse: # pyright: ignore + ) -> ChannelsResponse | MultipartChannelsResponse: return super().run(request, context, root_value) From 0ea4febb8d50ee215a85170a8fa49b409e10192c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 17:53:56 +0100 Subject: [PATCH 38/57] Pass request down to creat multipart response --- strawberry/aiohttp/views.py | 6 +++--- strawberry/asgi/__init__.py | 5 ++++- strawberry/channels/handlers/http_handler.py | 9 +++++++-- strawberry/django/views.py | 5 ++++- strawberry/fastapi/router.py | 5 ++++- strawberry/http/async_base_view.py | 7 +++++-- strawberry/quart/views.py | 5 ++++- strawberry/sanic/views.py | 1 + tests/http/test_multipart_subscription.py | 20 +++++++++++++++----- 9 files changed, 47 insertions(+), 16 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index cd0508a644..1bb239dee1 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -140,8 +140,6 @@ async def __call__(self, request: web.Request) -> web.StreamResponse: if not ws_test.ok: try: - # TODO: pass this down from run to multipart thingy - self.request = request return await self.run(request=request) except HTTPException as e: return web.Response( @@ -191,6 +189,7 @@ def create_response( async def create_multipart_response( self, + request: web.Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: web.Response, ) -> web.StreamResponse: @@ -204,10 +203,11 @@ async def create_multipart_response( reason="OK", ) - await response.prepare(self.request) + await response.prepare(request) async for data in stream(): await response.write(data.encode()) await response.write_eof() + return response diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index e82025fb77..5b2829e891 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -222,7 +222,10 @@ def create_response( return response async def create_multipart_response( - self, stream: Callable[[], AsyncIterator[str]], sub_response: Response + self, + request: Request, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 5afcfbe5fe..4d8fe0a01f 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -19,6 +19,7 @@ Optional, Union, ) +from typing_extensions import assert_never from urllib.parse import parse_qs from django.conf import settings @@ -212,15 +213,18 @@ async def handle(self, body: bytes) -> None: await self.send_headers(headers=response.headers) async for chunk in response.stream(): - # TODO: we should change more body await self.send_body(chunk.encode("utf-8"), more_body=True) - else: + await self.send_body(b"", more_body=False) + + elif isinstance(response, ChannelsResponse): await self.send_response( response.status, response.content, headers=response.headers, ) + else: + assert_never(response) except HTTPException as e: await self.send_response(e.status_code, e.reason.encode()) @@ -275,6 +279,7 @@ async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: async def create_multipart_response( self, + request: ChannelsRequest, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, ) -> MultipartChannelsResponse: diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 8647424093..a3be96ec5f 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -186,7 +186,10 @@ def create_response( return response async def create_multipart_response( - self, stream: Callable[[], AsyncIterator[Any]], sub_response: HttpResponse + self, + request: HttpRequest, + stream: Callable[[], AsyncIterator[Any]], + sub_response: HttpResponse, ) -> HttpResponseBase: return StreamingHttpResponse( streaming_content=stream(), diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 728f7fc587..08d2a019e0 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -326,7 +326,10 @@ def create_response( return response async def create_multipart_response( - self, stream: Callable[[], AsyncIterator[str]], sub_response: Response + self, + request: Request, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index d41f5c2d44..9ce3239401 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -105,7 +105,10 @@ async def render_graphql_ide(self, request: Request) -> Response: ... async def create_multipart_response( - self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: SubResponse + self, + request: Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: SubResponse, ) -> Response: raise ValueError("Multipart responses are not supported") @@ -210,7 +213,7 @@ async def run( if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) - return await self.create_multipart_response(stream, sub_response) + return await self.create_multipart_response(request, stream, sub_response) response_data = await self.process_result(request=request, result=result) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 1466e55e92..21cfe83ca4 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -104,7 +104,10 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore ) async def create_multipart_response( - self, stream: Callable[[], AsyncGenerator[str, None]], sub_response: Response + self, + request: Request, + stream: Callable[[], AsyncGenerator[str, None]], + sub_response: Response, ) -> Response: return ( stream(), diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 8c99b67178..69255acf68 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -187,6 +187,7 @@ async def get(self, request: Request) -> HTTPResponse: # type: ignore[override] async def create_multipart_response( self, + request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, ) -> HTTPResponse: diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 39431d3721..b33a6c37da 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -17,30 +17,40 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: from .clients.django import DjangoHttpClient if http_client_class is DjangoHttpClient: - pytest.skip(reason="(sync) DjangoHttpClient doesn't support subscriptions") + pytest.skip( + reason="(sync) DjangoHttpClient doesn't support multipart subscriptions" + ) with contextlib.suppress(ImportError): from .clients.channels import SyncChannelsHttpClient # TODO: why do we have a sync channels client? if http_client_class is SyncChannelsHttpClient: - pytest.skip(reason="SyncChannelsHttpClient doesn't support subscriptions") + pytest.skip( + reason="SyncChannelsHttpClient doesn't support multipart subscriptions" + ) with contextlib.suppress(ImportError): from .clients.async_flask import AsyncFlaskHttpClient from .clients.flask import FlaskHttpClient if http_client_class is FlaskHttpClient: - pytest.skip(reason="FlaskHttpClient doesn't support subscriptions") + pytest.skip( + reason="FlaskHttpClient doesn't support multipart subscriptions" + ) if http_client_class is AsyncFlaskHttpClient: - pytest.xfail(reason="AsyncFlaskHttpClient doesn't support subscriptions") + pytest.xfail( + reason="AsyncFlaskHttpClient doesn't support multipart subscriptions" + ) with contextlib.suppress(ImportError): from .clients.chalice import ChaliceHttpClient if http_client_class is ChaliceHttpClient: - pytest.skip(reason="ChaliceHttpClient doesn't support subscriptions") + pytest.skip( + reason="ChaliceHttpClient doesn't support multipart subscriptions" + ) return http_client_class() From abd03946f39306d24e289d49c18ce1a731dbe759 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 17:59:58 +0100 Subject: [PATCH 39/57] Fix tests with lowercase headers --- tests/http/clients/base.py | 20 +++++++++++++++++--- tests/http/test_query.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 54c897546e..2dc579484f 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -3,6 +3,7 @@ import json import logging from dataclasses import dataclass +from functools import cached_property from io import BytesIO from typing import ( Any, @@ -32,12 +33,25 @@ class Response: status_code: int data: Union[bytes, AsyncIterable[bytes]] - headers: Mapping[str, str] + + def __init__( + self, + status_code: int, + data: Union[bytes, AsyncIterable[bytes]], + *, + headers: Optional[Dict[str, str]] = None, + ) -> None: + self.status_code = status_code + self.data = data + self._headers = headers or {} + + @cached_property + def headers(self) -> Mapping[str, str]: + return {k.lower(): v for k, v in self._headers.items()} @property def is_multipart(self) -> bool: - # TODO: check casing - return self.headers.get("Content-type", "").startswith("multipart/mixed") + return self.headers.get("content-type", "").startswith("multipart/mixed") @property def text(self) -> str: diff --git a/tests/http/test_query.py b/tests/http/test_query.py index f60a122d8c..d7910fd220 100644 --- a/tests/http/test_query.py +++ b/tests/http/test_query.py @@ -216,4 +216,4 @@ async def test_updating_headers( assert response.status_code == 200 assert response.json["data"] == {"setHeader": "Jake"} - assert response.headers["X-Name"] == "Jake" + assert response.headers["x-name"] == "Jake" From 93928e8f95624178f707e17a28d11412ca7979ce Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 18:55:33 +0100 Subject: [PATCH 40/57] Get support --- strawberry/http/async_base_view.py | 3 +++ strawberry/types/graphql.py | 6 +++++- tests/http/clients/base.py | 12 +++++++----- tests/http/test_multipart_subscription.py | 15 ++++++++------- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 9ce3239401..658caba2e4 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -308,6 +308,9 @@ async def stream(): async def parse_multipart_subscriptions( self, request: AsyncHTTPRequestAdapter ) -> Dict[str, str]: + if request.method == "GET": + return self.parse_query_params(request.query_params) + return self.parse_json(await request.get_body()) async def parse_http_body( diff --git a/strawberry/types/graphql.py b/strawberry/types/graphql.py index de58492e1e..35417aa109 100644 --- a/strawberry/types/graphql.py +++ b/strawberry/types/graphql.py @@ -15,7 +15,11 @@ class OperationType(enum.Enum): @staticmethod def from_http(method: HTTPMethod) -> Set[OperationType]: if method == "GET": - return {OperationType.QUERY} + return { + OperationType.QUERY, + # subscriptions are supported via GET in the multipart protocol + OperationType.SUBSCRIPTION, + } if method == "POST": return { diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 2dc579484f..a20d9d0c92 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -163,16 +163,18 @@ def _get_headers( headers: Optional[Dict[str, str]], files: Optional[Dict[str, BytesIO]], ) -> Dict[str, str]: - addition_headers = {} + additional_headers = {} + headers = headers or {} - content_type = None + # TODO: fix case sensitivity + content_type = headers.get("content-type") - if method == "post" and not files: + if not content_type and method == "post" and not files: content_type = "application/json" - addition_headers = {"Content-Type": content_type} if content_type else {} + additional_headers = {"Content-Type": content_type} if content_type else {} - return addition_headers if headers is None else {**addition_headers, **headers} + return {**additional_headers, **headers} def _build_body( self, diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index b33a6c37da..d9cb289d8c 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,5 +1,6 @@ import contextlib from typing import Type +from typing_extensions import Literal import pytest @@ -55,13 +56,13 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: return http_client_class() -# TODO: do multipart subscriptions work on both GET and POST? -async def test_multipart_subscription(http_client: HttpClient): - response = await http_client.post( - url="/graphql", - json={ - "query": 'subscription { echo(message: "Hello world", delay: 0.2) }', - }, +@pytest.mark.parametrize("method", ["get", "post"]) +async def test_multipart_subscription( + http_client: HttpClient, method: Literal["get", "post"] +): + response = await http_client.query( + method=method, + query='subscription { echo(message: "Hello world", delay: 0.2) }', headers={ "content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, From 25af971284b13d1624b7afd412df017036c1e30a Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 16 Jan 2024 19:14:01 +0100 Subject: [PATCH 41/57] Some release notes --- RELEASE.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/RELEASE.md b/RELEASE.md index de7cc70ac5..749e0fcc41 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,18 @@ Release type: minor -TODO :) +This release adds support for multipart subscriptions in almost all\* +integrations! + +[Multipart subcriptions](https://www.apollographql.com/docs/router/executing-operations/subscription-multipart-protocol/) +are a new protocol from Apollo GraphQL, built on the +[Incremental Delivery over HTTP spec](https://github.com/graphql/graphql-over-http/blob/main/rfcs/IncrementalDelivery.md), +which is also used for `@defer` and `@stream`. + +The main advantage of this protocol is that when using the Apollo Client +libraries you don't need to install any additional dependency, but in future +this feature will make it easier for us to implement `@defer` and `@stream` + +Also, this means that you don't need to use Django Channels for subscription, +since this protocol is based on HTTP we don't need to use websockets. + +- Flask, Chalice and the sync Django integration don't support this yet. From 1667d1916c1fa043feda27cc3e3059b0a38041b2 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 24 Jan 2024 15:01:31 +0100 Subject: [PATCH 42/57] Litestar support --- pyproject.toml | 6 ++++++ strawberry/http/base.py | 4 +--- strawberry/litestar/controller.py | 25 ++++++++++++++++++++++++- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b4c8fdb8f5..f8284dbec2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,6 +174,12 @@ filterwarnings = [ # ignoring the text instead of the whole warning because we'd # get an error when django is not installed "ignore:The default value of USE_TZ", + "ignore::DeprecationWarning:pydantic_openapi_schema.*", + "ignore::DeprecationWarning:graphql.*", + "ignore::DeprecationWarning:websockets.*", + "ignore::DeprecationWarning:pydantic.*", + "ignore::UserWarning:pydantic.*", + "ignore::DeprecationWarning:pkg_resources.*", ] [tool.autopub] diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 5dc8fdc42f..2478c55574 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -86,6 +86,4 @@ def _is_multipart_subscriptions( if params.get("boundary") != "graphql": return False - return tuple( - part.strip() for part in params.get("subscriptionspec", "").split(",") - ) == ("1.0", "application/json") + return params.get("subscriptionspec") == "1.0" diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 5c9877427b..b388086fc3 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -7,6 +7,8 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Callable, Dict, FrozenSet, List, @@ -34,6 +36,7 @@ from litestar.background_tasks import BackgroundTasks from litestar.di import Provide from litestar.exceptions import NotFoundException, ValidationException +from litestar.response.streaming import Stream from litestar.status_codes import HTTP_200_OK from strawberry.exceptions import InvalidCustomContext from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter @@ -183,7 +186,13 @@ def headers(self) -> Mapping[str, str]: @property def content_type(self) -> Optional[str]: - return self.request.content_type[0] + content_type, params = self.request.content_type + + # combine content type and params + if params: + content_type += "; " + "; ".join(f"{k}={v}" for k, v in params.items()) + + return content_type async def get_body(self) -> bytes: return await self.request.body() @@ -271,6 +280,20 @@ def create_response( return response + async def create_multipart_response( + self, + request: Request, + stream: Callable[[], AsyncIterator[str]], + sub_response: Response, + ) -> Response: + return Stream( + stream(), + headers={ + "Transfer-Encoding": "chunked", + "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + @get(raises=[ValidationException, NotFoundException]) async def handle_http_get( self, From b15b5ac811037a476052b61b809bdeb3f630dfd1 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 26 Jan 2024 18:09:38 +0100 Subject: [PATCH 43/57] Add docs --- docs/README.md | 1 + docs/general/multipart-subscriptions.md | 37 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 docs/general/multipart-subscriptions.md diff --git a/docs/README.md b/docs/README.md index 1174034d1a..5a65aeb236 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,7 @@ - [Queries](./general/queries.md) - [Mutations](./general/mutations.md) - [Subscriptions](./general/subscriptions.md) +- [Multipart Subscriptions](./general/multipart-subscriptions.md) - [Why](./general/why.md) - [Breaking changes](./breaking-changes.md) - [Upgrading Strawberry](./general/upgrades.md) diff --git a/docs/general/multipart-subscriptions.md b/docs/general/multipart-subscriptions.md new file mode 100644 index 0000000000..ef7157432b --- /dev/null +++ b/docs/general/multipart-subscriptions.md @@ -0,0 +1,37 @@ +--- +title: Multipart subscriptions +--- + +# Multipart subscriptions + +Strawberry supports subscription over multipart responses. This is an +[alternative protocol](https://www.apollographql.com/docs/router/executing-operations/subscription-multipart-protocol/) +created by [Apollo](https://www.apollographql.com/) to support subscriptions +over HTTP, and it is supported by default by Apollo Client. + +# Support + +We support multipart subscriptions out of the box in the following HTTP +libraries: + +- Django (only in the Async view) +- ASGI +- Litestar +- FastAPI +- AioHTTP + +# Usage + +Multipart subscriptions are automatically enabled when using Subscription, so no +additional configuration is required. + +# Limitations + +At the moment, we don't support the following features: + +- Changing the status code of the response +- Changing the headers of the response + +We might add support for these features in the future, but it's clear how they +would work in the context of a subscription. If you have any ideas feel free to +reach out on our [discord server](https://strawberry.rocks/discord). From 49352018ffdbdd06a40ba2f42c8f1ce06c9592ec Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 26 Jan 2024 18:12:02 +0100 Subject: [PATCH 44/57] Remove subresponse --- docs/general/multipart-subscriptions.md | 1 + strawberry/aiohttp/views.py | 2 -- strawberry/asgi/__init__.py | 1 - strawberry/channels/handlers/http_handler.py | 1 - strawberry/django/views.py | 1 - strawberry/fastapi/router.py | 1 - strawberry/http/async_base_view.py | 3 +-- strawberry/litestar/controller.py | 1 - strawberry/quart/views.py | 1 - strawberry/sanic/views.py | 1 - 10 files changed, 2 insertions(+), 11 deletions(-) diff --git a/docs/general/multipart-subscriptions.md b/docs/general/multipart-subscriptions.md index ef7157432b..6da191ca72 100644 --- a/docs/general/multipart-subscriptions.md +++ b/docs/general/multipart-subscriptions.md @@ -19,6 +19,7 @@ libraries: - Litestar - FastAPI - AioHTTP +- Quart # Usage diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 1bb239dee1..033586dae0 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -191,9 +191,7 @@ async def create_multipart_response( self, request: web.Request, stream: Callable[[], AsyncGenerator[str, None]], - sub_response: web.Response, ) -> web.StreamResponse: - # TODO: use sub response response = web.StreamResponse( status=200, headers={ diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 5b2829e891..271f3b71a3 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -225,7 +225,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], - sub_response: Response, ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 4d8fe0a01f..87309b9659 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -281,7 +281,6 @@ async def create_multipart_response( self, request: ChannelsRequest, stream: Callable[[], AsyncGenerator[str, None]], - sub_response: TemporalResponse, ) -> MultipartChannelsResponse: # TODO: sub response return MultipartChannelsResponse(stream=stream) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index a3be96ec5f..22912db83b 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -189,7 +189,6 @@ async def create_multipart_response( self, request: HttpRequest, stream: Callable[[], AsyncIterator[Any]], - sub_response: HttpResponse, ) -> HttpResponseBase: return StreamingHttpResponse( streaming_content=stream(), diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 08d2a019e0..85f6558585 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -329,7 +329,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], - sub_response: Response, ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 658caba2e4..de9e8578c0 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -108,7 +108,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], - sub_response: SubResponse, ) -> Response: raise ValueError("Multipart responses are not supported") @@ -213,7 +212,7 @@ async def run( if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) - return await self.create_multipart_response(request, stream, sub_response) + return await self.create_multipart_response(request, stream) response_data = await self.process_result(request=request, result=result) diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index b388086fc3..ea7937f25c 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -284,7 +284,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], - sub_response: Response, ) -> Response: return Stream( stream(), diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 21cfe83ca4..c1601a6683 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -107,7 +107,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], - sub_response: Response, ) -> Response: return ( stream(), diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 69255acf68..f7f44cd7ce 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -189,7 +189,6 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], - sub_response: TemporalResponse, ) -> HTTPResponse: response = await self.request.respond( content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", From 4f5f1c48c3d7cdb9c4ac245fe86fb39eda43cb20 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 2 Mar 2024 02:29:35 +0100 Subject: [PATCH 45/57] Fix check --- strawberry/http/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 2478c55574..b69766ec58 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -86,4 +86,4 @@ def _is_multipart_subscriptions( if params.get("boundary") != "graphql": return False - return params.get("subscriptionspec") == "1.0" + return params.get("subscriptionspec", "").startswith("1.0") From 7ed9e0eed5983c4c363b26693b5467a0351fc0c2 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Thu, 29 Aug 2024 19:51:39 +0200 Subject: [PATCH 46/57] Fix --- strawberry/schema/execute.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 8ccff111c5..a4b6f3161e 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -15,9 +15,8 @@ Union, ) -from graphql import GraphQLError, parse +from graphql import GraphQLError, parse, subscribe from graphql import execute as original_execute -from graphql.execution.subscribe import subscribe from graphql.validation import validate from strawberry.exceptions import MissingQueryError From f0cffb1cda883252adbf3aaffc01c8c49ba8551c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Thu, 29 Aug 2024 23:50:46 +0200 Subject: [PATCH 47/57] Fix --- strawberry/http/async_base_view.py | 6 +++--- strawberry/http/sync_base_view.py | 7 ++++--- strawberry/schema/execute.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index dabd98958c..6791aad882 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -304,14 +304,14 @@ async def parse_http_body( ) -> GraphQLRequestData: content_type, params = parse_content_type(request.content_type or "") - if "application/json" in content_type: + if request.method == "GET": + data = self.parse_query_params(request.query_params) + elif "application/json" in content_type: data = self.parse_json(await request.get_body()) elif content_type == "multipart/form-data": data = await self.parse_multipart(request) elif self._is_multipart_subscriptions(content_type, params): data = await self.parse_multipart_subscriptions(request) - elif request.method == "GET": - data = self.parse_query_params(request.query_params) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 5c36b94d91..f1ce7ca19a 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -138,12 +138,13 @@ def parse_multipart(self, request: SyncHTTPRequestAdapter) -> Dict[str, str]: def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData: content_type, params = parse_content_type(request.content_type or "") - if "application/json" in content_type: + if request.method == "GET": + data = self.parse_query_params(request.query_params) + elif "application/json" in content_type: data = self.parse_json(request.body) + # TODO: multipart via get? elif content_type == "multipart/form-data": data = self.parse_multipart(request) - elif request.method == "GET": - data = self.parse_query_params(request.query_params) elif self._is_multipart_subscriptions(content_type, params): raise HTTPException( 400, "Multipart subcriptions are not supported in sync mode" diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index a4b6f3161e..674415d75d 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -24,6 +24,7 @@ from strawberry.schema.validation_rules.one_of import OneOfInputValidationRule from strawberry.types import ExecutionResult from strawberry.types.execution import SubscriptionExecutionResult +from strawberry.types.graphql import OperationType from .exceptions import InvalidOperationTypeError @@ -37,7 +38,6 @@ from strawberry.extensions import SchemaExtension from strawberry.types import ExecutionContext - from strawberry.types.graphql import OperationType # duplicated because of https://github.com/mkdocstrings/griffe-typingdoc/issues/7 From 10176fa4ef6a57e469bd4f8415d85643f58893ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:52:16 +0000 Subject: [PATCH 48/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/http/parse_content_type.py | 1 - 1 file changed, 1 deletion(-) diff --git a/strawberry/http/parse_content_type.py b/strawberry/http/parse_content_type.py index 424d434964..d28be1a337 100644 --- a/strawberry/http/parse_content_type.py +++ b/strawberry/http/parse_content_type.py @@ -4,7 +4,6 @@ def parse_content_type(content_type: str) -> Tuple[str, Dict[str, str]]: """Parse a content type header into a mime-type and a dictionary of parameters.""" - email = Message() email["content-type"] = content_type From 77aef3a6871cdae0ff548beca8375d05d4b35f11 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 30 Aug 2024 10:43:14 +0200 Subject: [PATCH 49/57] Lint --- strawberry/asgi/__init__.py | 2 +- strawberry/channels/handlers/http_handler.py | 2 +- strawberry/schema/schema.py | 3 --- .../subscriptions/protocols/graphql_transport_ws/handlers.py | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d88628c57d..ee3df73333 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -223,7 +223,7 @@ def create_response( async def create_multipart_response( self, - request: Request, + request: Request | WebSocket, stream: Callable[[], AsyncIterator[str]], ) -> Response: return StreamingResponse( diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index f36b1be13f..b47fb3500d 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -335,7 +335,7 @@ def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: # handlers in a threadpool. Check SyncConsumer's documentation for more info: # https://github.com/django/channels/blob/main/channels/consumer.py#L104 @database_sync_to_async # pyright: ignore[reportIncompatibleMethodOverride] - def run( # type: ignore[override] + def run( self, request: ChannelsRequest, context: Optional[Context] = UNSET, diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index f209b021a4..ec4bf0d64a 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -53,15 +53,12 @@ from strawberry.directive import StrawberryDirective from strawberry.extensions import SchemaExtension - from strawberry.field import StrawberryField - from strawberry.type import StrawberryType from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.base import StrawberryType from strawberry.types.enum import EnumDefinition from strawberry.types.field import StrawberryField from strawberry.types.scalar import ScalarDefinition, ScalarWrapper from strawberry.types.union import StrawberryUnion - from strawberry.union import StrawberryUnion DEFAULT_ALLOWED_OPERATION_TYPES = { OperationType.QUERY, diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index fa0cb7c177..1275ecf304 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -256,7 +256,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: else: # create AsyncGenerator returning a single result async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( + yield await self.schema.execute( # type: ignore query=message.payload.query, variable_values=message.payload.variables, context_value=context, From 95e76bc39f0b27eeaf7971bf1c0c12843c31cf44 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 30 Aug 2024 18:37:26 +0200 Subject: [PATCH 50/57] Sanic fix --- strawberry/sanic/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index a45d703056..391923d4b8 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -195,7 +195,7 @@ async def create_multipart_response( await response.eof() - return response + return None __all__ = ["GraphQLView"] From b811ad1c12b2aedf22e9c2c188c23586080cffa0 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 30 Aug 2024 23:48:43 +0200 Subject: [PATCH 51/57] Fixes --- strawberry/experimental/pydantic/_compat.py | 1 - strawberry/experimental/pydantic/error_type.py | 1 - strawberry/experimental/pydantic/fields.py | 1 - strawberry/experimental/pydantic/utils.py | 1 - strawberry/http/async_base_view.py | 12 +++++------- strawberry/sanic/views.py | 6 +++++- strawberry/schema/execute.py | 2 +- tests/experimental/pydantic/schema/test_basic.py | 2 +- tests/experimental/pydantic/schema/test_defaults.py | 1 - .../experimental/pydantic/schema/test_federation.py | 3 +-- .../pydantic/schema/test_forward_reference.py | 1 - tests/experimental/pydantic/schema/test_mutation.py | 1 - tests/experimental/pydantic/test_basic.py | 2 +- tests/experimental/pydantic/test_conversion.py | 2 +- tests/experimental/pydantic/test_error_type.py | 2 +- tests/experimental/pydantic/test_fields.py | 4 ++-- 16 files changed, 18 insertions(+), 24 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 9971c962b5..dc3c2fe08f 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -8,7 +8,6 @@ import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION - from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError if TYPE_CHECKING: diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index bbedfe610b..d07284e145 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -15,7 +15,6 @@ ) from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 9cac486290..3ded38fc29 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -3,7 +3,6 @@ from typing_extensions import Annotated from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( PydanticCompat, get_args, diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 912553fb98..9987c2ee59 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,7 +14,6 @@ ) from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 6791aad882..587ff1ea37 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -222,14 +222,12 @@ def encode_multipart_data(self, data: Any, separator: str) -> str: def _stream_with_heartbeat( self, stream: Callable[[], AsyncGenerator[str, None]] ) -> Callable[[], AsyncGenerator[str, None]]: - """Adds a heartbeat to the stream, to prevent the connection from closing - when there are no messages being sent. - """ + """Adds a heartbeat to the stream, to prevent the connection from closing when there are no messages being sent.""" queue = asyncio.Queue[Tuple[bool, Any]](1) cancelling = False - async def drain(): + async def drain() -> None: try: async for item in stream(): await queue.put((False, item)) @@ -239,7 +237,7 @@ async def drain(): else: raise - async def heartbeat(): + async def heartbeat() -> None: while True: await queue.put((False, self.encode_multipart_data({}, "graphql"))) @@ -249,7 +247,7 @@ async def merged() -> AsyncGenerator[str, None]: heartbeat_task = asyncio.create_task(heartbeat()) task = asyncio.create_task(drain()) - async def cancel_tasks(): + async def cancel_tasks() -> None: nonlocal cancelling cancelling = True task.cancel() @@ -282,7 +280,7 @@ def _get_stream( result: SubscriptionExecutionResult, separator: str = "graphql", ) -> Callable[[], AsyncGenerator[str, None]]: - async def stream(): + async def stream() -> AsyncGenerator[str, None]: async for value in result: response = await self.process_result(request, value) yield self.encode_multipart_data({"payload": response}, separator) diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 391923d4b8..455fc02e55 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -195,7 +195,11 @@ async def create_multipart_response( await response.eof() - return None + # returning the response will basically tell sanic to send it again + # to the client, so we return None to avoid that, and we ignore the type + # error mostly so we don't have to update the types everywhere for this + # corner case + return None # type: ignore __all__ = ["GraphQLView"] diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 674415d75d..036f50c540 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -23,7 +23,6 @@ from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.schema.validation_rules.one_of import OneOfInputValidationRule from strawberry.types import ExecutionResult -from strawberry.types.execution import SubscriptionExecutionResult from strawberry.types.graphql import OperationType from .exceptions import InvalidOperationTypeError @@ -38,6 +37,7 @@ from strawberry.extensions import SchemaExtension from strawberry.types import ExecutionContext + from strawberry.types.execution import SubscriptionExecutionResult # duplicated because of https://github.com/mkdocstrings/griffe-typingdoc/issues/7 diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index f816208978..75b6411bd2 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -3,9 +3,9 @@ from enum import Enum from typing import List, Optional, Union -import pydantic import pytest +import pydantic import strawberry from tests.experimental.pydantic.utils import needs_pydantic_v1 diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index 2412e6b04c..6939f7a0ba 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -2,7 +2,6 @@ from typing import Optional import pydantic - import strawberry from strawberry.printer import print_schema from tests.conftest import skip_if_gql_32 diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index 47bd56c2f9..90a506d62e 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,8 +1,7 @@ import typing -from pydantic import BaseModel - import strawberry +from pydantic import BaseModel from strawberry.federation.schema_directives import Key diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index ebc94d4b37..23ad750b51 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -4,7 +4,6 @@ from typing import Optional import pydantic - import strawberry diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index 03e545eece..a1efa945cd 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,7 +1,6 @@ from typing import Dict, List, Union import pydantic - import strawberry from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2 diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 9556d3f617..2540ff99f1 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -3,9 +3,9 @@ from typing import Any, List, Optional, Union from typing_extensions import Annotated -import pydantic import pytest +import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.schema_directive import Location diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index cc9b5a81e0..eb4f7c0495 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -6,9 +6,9 @@ from typing import Any, Dict, List, NewType, Optional, TypeVar, Union import pytest -from pydantic import BaseModel, Field, ValidationError import strawberry +from pydantic import BaseModel, Field, ValidationError from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, CompatModelField, diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index 8e37c6402c..37af6e51e3 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -1,8 +1,8 @@ from typing import List, Optional -import pydantic import pytest +import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.types.base import ( diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 878969b9af..884236f704 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -2,11 +2,11 @@ from typing import List from typing_extensions import Literal -import pydantic import pytest -from pydantic import BaseModel, ValidationError, conlist +import pydantic import strawberry +from pydantic import BaseModel, ValidationError, conlist from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V1 from strawberry.types.base import StrawberryObjectDefinition, StrawberryOptional from tests.experimental.pydantic.utils import needs_pydantic_v1, needs_pydantic_v2 From a74632f74afdb342e14ec8f619140c8c7c4fe5b4 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 31 Aug 2024 12:01:14 +0200 Subject: [PATCH 52/57] Update release file and add tweet file --- RELEASE.md | 8 ++++---- TWEET.md | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 TWEET.md diff --git a/RELEASE.md b/RELEASE.md index 749e0fcc41..a412cb7002 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,7 +1,7 @@ Release type: minor -This release adds support for multipart subscriptions in almost all\* -integrations! +This release adds support for multipart subscriptions in almost all[^1] of our +http integrations! [Multipart subcriptions](https://www.apollographql.com/docs/router/executing-operations/subscription-multipart-protocol/) are a new protocol from Apollo GraphQL, built on the @@ -10,9 +10,9 @@ which is also used for `@defer` and `@stream`. The main advantage of this protocol is that when using the Apollo Client libraries you don't need to install any additional dependency, but in future -this feature will make it easier for us to implement `@defer` and `@stream` +this feature should make it easier for us to implement `@defer` and `@stream` Also, this means that you don't need to use Django Channels for subscription, since this protocol is based on HTTP we don't need to use websockets. -- Flask, Chalice and the sync Django integration don't support this yet. +[^1]: Flask, Chalice and the sync Django integration don't support this. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..bf2206ef4d --- /dev/null +++ b/TWEET.md @@ -0,0 +1,9 @@ +🆕 Release $version is out! Thanks to $contributor for the PR 👏 + +Strawberry GraphQL now supports @apollographql's multipart subscriptions! 🎉 + +This means that when using the Apollo Client libraries you don't need to install +any additional dependency, but in future this feature should make it easier for +us to implement `@defer` and `@stream`. + +Get it here 👉 $release_url From 1ba20741372c3aeae7a76da3f04e144ad8e6e116 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:02:03 +0000 Subject: [PATCH 53/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/experimental/pydantic/_compat.py | 1 + strawberry/experimental/pydantic/error_type.py | 1 + strawberry/experimental/pydantic/fields.py | 1 + strawberry/experimental/pydantic/utils.py | 1 + tests/experimental/pydantic/schema/test_basic.py | 2 +- tests/experimental/pydantic/schema/test_defaults.py | 1 + tests/experimental/pydantic/schema/test_federation.py | 3 ++- tests/experimental/pydantic/schema/test_forward_reference.py | 1 + tests/experimental/pydantic/schema/test_mutation.py | 1 + tests/experimental/pydantic/test_basic.py | 2 +- tests/experimental/pydantic/test_conversion.py | 2 +- tests/experimental/pydantic/test_error_type.py | 2 +- tests/experimental/pydantic/test_fields.py | 4 ++-- 13 files changed, 15 insertions(+), 7 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index dc3c2fe08f..9971c962b5 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -8,6 +8,7 @@ import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION + from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError if TYPE_CHECKING: diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index d07284e145..bbedfe610b 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -15,6 +15,7 @@ ) from pydantic import BaseModel + from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 3ded38fc29..9cac486290 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -3,6 +3,7 @@ from typing_extensions import Annotated from pydantic import BaseModel + from strawberry.experimental.pydantic._compat import ( PydanticCompat, get_args, diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 9987c2ee59..912553fb98 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,6 +14,7 @@ ) from pydantic import BaseModel + from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 75b6411bd2..f816208978 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -3,9 +3,9 @@ from enum import Enum from typing import List, Optional, Union +import pydantic import pytest -import pydantic import strawberry from tests.experimental.pydantic.utils import needs_pydantic_v1 diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index 6939f7a0ba..2412e6b04c 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -2,6 +2,7 @@ from typing import Optional import pydantic + import strawberry from strawberry.printer import print_schema from tests.conftest import skip_if_gql_32 diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index 90a506d62e..47bd56c2f9 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,7 +1,8 @@ import typing -import strawberry from pydantic import BaseModel + +import strawberry from strawberry.federation.schema_directives import Key diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index 23ad750b51..ebc94d4b37 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -4,6 +4,7 @@ from typing import Optional import pydantic + import strawberry diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index a1efa945cd..03e545eece 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,6 +1,7 @@ from typing import Dict, List, Union import pydantic + import strawberry from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2 diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 2540ff99f1..9556d3f617 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -3,9 +3,9 @@ from typing import Any, List, Optional, Union from typing_extensions import Annotated +import pydantic import pytest -import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.schema_directive import Location diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index eb4f7c0495..cc9b5a81e0 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -6,9 +6,9 @@ from typing import Any, Dict, List, NewType, Optional, TypeVar, Union import pytest +from pydantic import BaseModel, Field, ValidationError import strawberry -from pydantic import BaseModel, Field, ValidationError from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, CompatModelField, diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index 37af6e51e3..8e37c6402c 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -1,8 +1,8 @@ from typing import List, Optional +import pydantic import pytest -import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.types.base import ( diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 884236f704..878969b9af 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -2,11 +2,11 @@ from typing import List from typing_extensions import Literal +import pydantic import pytest +from pydantic import BaseModel, ValidationError, conlist -import pydantic import strawberry -from pydantic import BaseModel, ValidationError, conlist from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V1 from strawberry.types.base import StrawberryObjectDefinition, StrawberryOptional from tests.experimental.pydantic.utils import needs_pydantic_v1, needs_pydantic_v2 From bca86779bfe22348a77fd6c28ee967eaf41a70fe Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 31 Aug 2024 12:12:37 +0200 Subject: [PATCH 54/57] Fix tweet --- TWEET.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/TWEET.md b/TWEET.md index bf2206ef4d..da1ca72a69 100644 --- a/TWEET.md +++ b/TWEET.md @@ -2,8 +2,6 @@ Strawberry GraphQL now supports @apollographql's multipart subscriptions! 🎉 -This means that when using the Apollo Client libraries you don't need to install -any additional dependency, but in future this feature should make it easier for -us to implement `@defer` and `@stream`. +Now, when using the Apollo Client libraries you don't need to install any additional dependency 😊 Get it here 👉 $release_url From 3bfd9b61afdd85c405495a6a23e67acc57ed3a76 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 31 Aug 2024 12:43:59 +0200 Subject: [PATCH 55/57] Update tweet --- TWEET.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/TWEET.md b/TWEET.md index da1ca72a69..17b5d5a1c2 100644 --- a/TWEET.md +++ b/TWEET.md @@ -2,6 +2,4 @@ Strawberry GraphQL now supports @apollographql's multipart subscriptions! 🎉 -Now, when using the Apollo Client libraries you don't need to install any additional dependency 😊 - Get it here 👉 $release_url From 259b1e5c03b743547f551a7c8397711fa3894870 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 31 Aug 2024 13:07:04 +0200 Subject: [PATCH 56/57] Sub response support --- strawberry/aiohttp/views.py | 5 +++-- strawberry/asgi/__init__.py | 3 +++ strawberry/channels/handlers/http_handler.py | 11 ++++------- strawberry/django/views.py | 3 +++ strawberry/fastapi/router.py | 3 +++ strawberry/http/async_base_view.py | 3 ++- strawberry/litestar/controller.py | 3 +++ strawberry/quart/views.py | 4 +++- strawberry/sanic/views.py | 3 +++ 9 files changed, 27 insertions(+), 11 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index a2cc4975aa..56a755b2c9 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -193,14 +193,15 @@ async def create_multipart_response( self, request: web.Request, stream: Callable[[], AsyncGenerator[str, None]], + sub_response: web.Response, ) -> web.StreamResponse: response = web.StreamResponse( - status=200, + status=sub_response.status, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, - reason="OK", ) await response.prepare(request) diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index ee3df73333..26ead659e6 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -225,10 +225,13 @@ async def create_multipart_response( self, request: Request | WebSocket, stream: Callable[[], AsyncIterator[str]], + sub_response: Response, ) -> Response: return StreamingResponse( stream(), + status_code=sub_response.status_code, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index b47fb3500d..e7a96d1d7b 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -1,8 +1,3 @@ -"""GraphQLHTTPHandler. - -A consumer to provide a graphql endpoint, and optionally graphiql. -""" - from __future__ import annotations import dataclasses @@ -282,9 +277,11 @@ async def create_multipart_response( self, request: ChannelsRequest, stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, ) -> MultipartChannelsResponse: - # TODO: sub response - return MultipartChannelsResponse(stream=stream) + status = sub_response.status_code or 200 + headers = {k.encode(): v.encode() for k, v in sub_response.headers.items()} + return MultipartChannelsResponse(stream=stream, status=status, headers=headers) async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: return ChannelsResponse( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 932beac5dd..ee831eabcb 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -189,10 +189,13 @@ async def create_multipart_response( self, request: HttpRequest, stream: Callable[[], AsyncIterator[Any]], + sub_response: TemporalHttpResponse, ) -> HttpResponseBase: return StreamingHttpResponse( streaming_content=stream(), + status=sub_response.status_code, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index b503d49af7..e25dfcd820 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -336,10 +336,13 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], + sub_response: Response, ) -> Response: return StreamingResponse( stream(), + status_code=sub_response.status_code, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 587ff1ea37..e210861a55 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -96,6 +96,7 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], + sub_response: SubResponse, ) -> Response: raise ValueError("Multipart responses are not supported") @@ -198,7 +199,7 @@ async def run( if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) - return await self.create_multipart_response(request, stream) + return await self.create_multipart_response(request, stream, sub_response) response_data = await self.process_result(request=request, result=result) diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 09a5a5db67..dc4e37a0af 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -284,10 +284,13 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], + sub_response: Response, ) -> Response: return Stream( stream(), + status_code=sub_response.status_code, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index f841501744..e6938a6034 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -107,11 +107,13 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], + sub_response: Response, ) -> Response: return ( stream(), - 200, + sub_response.status_code, { # type: ignore + **sub_response.headers, "Transfer-Encoding": "chunked", "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", }, diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 455fc02e55..83b7d3ca5c 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -182,10 +182,13 @@ async def create_multipart_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], + sub_response: TemporalResponse, ) -> HTTPResponse: response = await self.request.respond( content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + status=sub_response.status_code, headers={ + **sub_response.headers, "Transfer-Encoding": "chunked", }, ) From 20e8bf7bb7a4f5bac1d12033b4ca0a708e1112f8 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 31 Aug 2024 13:07:48 +0200 Subject: [PATCH 57/57] Update docs --- docs/general/multipart-subscriptions.md | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/docs/general/multipart-subscriptions.md b/docs/general/multipart-subscriptions.md index 6da191ca72..bbaeba768c 100644 --- a/docs/general/multipart-subscriptions.md +++ b/docs/general/multipart-subscriptions.md @@ -25,14 +25,3 @@ libraries: Multipart subscriptions are automatically enabled when using Subscription, so no additional configuration is required. - -# Limitations - -At the moment, we don't support the following features: - -- Changing the status code of the response -- Changing the headers of the response - -We might add support for these features in the future, but it's clear how they -would work in the context of a subscription. If you have any ideas feel free to -reach out on our [discord server](https://strawberry.rocks/discord).