diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index b49bf981..74356fa0 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -2043,6 +2043,7 @@ def subscribe( type_resolver: GraphQLTypeResolver | None = None, subscribe_field_resolver: GraphQLFieldResolver | None = None, execution_context_class: type[ExecutionContext] | None = None, + middleware: MiddlewareManager | None = None, ) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]: """Create a GraphQL subscription. @@ -2082,6 +2083,7 @@ def subscribe( field_resolver, type_resolver, subscribe_field_resolver, + middleware=middleware, ) # Return early errors if execution context failed. diff --git a/tests/execution/test_middleware.py b/tests/execution/test_middleware.py index 4927b52f..d4abba95 100644 --- a/tests/execution/test_middleware.py +++ b/tests/execution/test_middleware.py @@ -1,7 +1,8 @@ +import inspect from typing import Awaitable, cast import pytest -from graphql.execution import Middleware, MiddlewareManager, execute +from graphql.execution import Middleware, MiddlewareManager, execute, subscribe from graphql.language.parser import parse from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString @@ -236,6 +237,45 @@ async def resolve(self, next_, *args, **kwargs): result = await awaitable_result assert result.data == {"field": "devloseR"} + @pytest.mark.asyncio() + async def subscription_simple(): + async def bar_resolve(_obj, _info): + yield "bar" + yield "oof" + + test_type = GraphQLObjectType( + "Subscription", + { + "bar": GraphQLField( + GraphQLString, + resolve=lambda message, _info: message, + subscribe=bar_resolve, + ), + }, + ) + doc = parse("subscription { bar }") + + async def reverse_middleware(next_, value, info, **kwargs): + awaitable_maybe = next_(value, info, **kwargs) + return awaitable_maybe[::-1] + + noop_type = GraphQLObjectType( + "Noop", + {"noop": GraphQLField(GraphQLString)}, + ) + schema = GraphQLSchema(query=noop_type, subscription=test_type) + + agen = subscribe( + schema, + doc, + middleware=MiddlewareManager(reverse_middleware), + ) + assert inspect.isasyncgen(agen) + data = (await agen.__anext__()).data + assert data == {"bar": "rab"} + data = (await agen.__anext__()).data + assert data == {"bar": "foo"} + def describe_without_manager(): def no_middleware(): doc = parse("{ field }")