Skip to content

Commit

Permalink
Support middlewares for subscriptions (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrbnlulu authored Jul 7, 2024
1 parent a5a2a65 commit 876aef6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,6 +2043,7 @@ def subscribe(
type_resolver: GraphQLTypeResolver | None = None,
subscribe_field_resolver: GraphQLFieldResolver | None = None,
execution_context_class: type[ExecutionContext] | None = None,
middleware: MiddlewareManager | None = None,
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
"""Create a GraphQL subscription.
Expand Down Expand Up @@ -2082,6 +2083,7 @@ def subscribe(
field_resolver,
type_resolver,
subscribe_field_resolver,
middleware=middleware,
)

# Return early errors if execution context failed.
Expand Down
42 changes: 41 additions & 1 deletion tests/execution/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
from typing import Awaitable, cast

import pytest
from graphql.execution import Middleware, MiddlewareManager, execute
from graphql.execution import Middleware, MiddlewareManager, execute, subscribe
from graphql.language.parser import parse
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString

Expand Down Expand Up @@ -236,6 +237,45 @@ async def resolve(self, next_, *args, **kwargs):
result = await awaitable_result
assert result.data == {"field": "devloseR"}

@pytest.mark.asyncio()
async def subscription_simple():
async def bar_resolve(_obj, _info):
yield "bar"
yield "oof"

test_type = GraphQLObjectType(
"Subscription",
{
"bar": GraphQLField(
GraphQLString,
resolve=lambda message, _info: message,
subscribe=bar_resolve,
),
},
)
doc = parse("subscription { bar }")

async def reverse_middleware(next_, value, info, **kwargs):
awaitable_maybe = next_(value, info, **kwargs)
return awaitable_maybe[::-1]

noop_type = GraphQLObjectType(
"Noop",
{"noop": GraphQLField(GraphQLString)},
)
schema = GraphQLSchema(query=noop_type, subscription=test_type)

agen = subscribe(
schema,
doc,
middleware=MiddlewareManager(reverse_middleware),
)
assert inspect.isasyncgen(agen)
data = (await agen.__anext__()).data
assert data == {"bar": "rab"}
data = (await agen.__anext__()).data
assert data == {"bar": "foo"}

def describe_without_manager():
def no_middleware():
doc = parse("{ field }")
Expand Down

0 comments on commit 876aef6

Please sign in to comment.