Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graphql-transport-ws websocket subprotocol #65

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions graphql_ws/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from graphql import format_error, graphql

from .constants import (
GQL_COMPLETE,
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_NEXT,
GQL_START,
GQL_STOP,
GQL_SUBSCRIBE,
TRANSPORT_WS_PROTOCOL,
)


Expand All @@ -19,10 +23,15 @@ class ConnectionClosedException(Exception):


class BaseConnectionContext(object):
transport_ws_protocol = False

def __init__(self, ws, request_context=None):
self.ws = ws
self.operations = {}
self.request_context = request_context
self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
request_context.get("subprotocols") or []
)

def has_operation(self, op_id):
return op_id in self.operations
Expand All @@ -41,7 +50,7 @@ def remove_operation(self, op_id):

def unsubscribe(self, op_id):
async_iterator = self.remove_operation(op_id)
if hasattr(async_iterator, 'dispose'):
if hasattr(async_iterator, "dispose"):
async_iterator.dispose()
return async_iterator

Expand Down Expand Up @@ -84,12 +93,16 @@ def process_message(self, connection_context, parsed_message):
elif op_type == GQL_CONNECTION_TERMINATE:
return self.on_connection_terminate(connection_context, op_id)

elif op_type == GQL_START:
elif op_type == (
GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START
):
assert isinstance(payload, dict), "The payload must be a dict"
params = self.get_graphql_params(connection_context, payload)
return self.on_start(connection_context, op_id, params)

elif op_type == GQL_STOP:
elif op_type == (
GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP
):
return self.on_stop(connection_context, op_id)

else:
Expand Down Expand Up @@ -142,7 +155,12 @@ def build_message(self, id, op_type, payload):

def send_execution_result(self, connection_context, op_id, execution_result):
result = self.execution_result_to_dict(execution_result)
return self.send_message(connection_context, op_id, GQL_DATA, result)
return self.send_message(
connection_context,
op_id,
GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA,
result,
)

def execution_result_to_dict(self, execution_result):
result = OrderedDict()
Expand Down
10 changes: 7 additions & 3 deletions graphql_ws/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
GRAPHQL_WS = "graphql-ws"
WS_PROTOCOL = GRAPHQL_WS
TRANSPORT_WS_PROTOCOL = "graphql-transport-ws"

GQL_CONNECTION_INIT = "connection_init" # Client -> Server
GQL_CONNECTION_ACK = "connection_ack" # Server -> Client
Expand All @@ -8,8 +9,11 @@
# NOTE: This one here don't follow the standard due to connection optimization
GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server
GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client
GQL_START = "start" # Client -> Server
GQL_DATA = "data" # Server -> Client
GQL_START = "start" # Client -> Server (graphql-ws)
GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent)
GQL_DATA = "data" # Server -> Client (graphql-ws)
GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent)
GQL_ERROR = "error" # Server -> Client
GQL_COMPLETE = "complete" # Server -> Client
GQL_STOP = "stop" # Client -> Server
# (and Client -> Server for graphql-transport-ws STOP equivalent)
GQL_STOP = "stop" # Client -> Server (graphql-ws only)
20 changes: 12 additions & 8 deletions graphql_ws/django/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@

from channels.generic.websocket import AsyncJsonWebsocketConsumer

from ..constants import WS_PROTOCOL
from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL
from .subscriptions import subscription_server


class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):

async def connect(self):
self.connection_context = None
if WS_PROTOCOL in self.scope["subprotocols"]:
self.connection_context = await subscription_server.handle(
ws=self, request_context=self.scope
)
await self.accept(subprotocol=WS_PROTOCOL)
else:
found_protocol = None
for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]:
if protocol in self.scope["subprotocols"]:
found_protocol = protocol
break
if not found_protocol:
await self.close()
return
self.connection_context = await subscription_server.handle(
ws=self, request_context=self.scope
)
await self.accept(subprotocol=found_protocol)

async def disconnect(self, code):
if self.connection_context:
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ test =
pytest-asyncio; python_version>="3.4"
graphene>=2.0,<3
gevent
graphene>=2.0
graphene_django
mock; python_version<"3"
mock; python_version<"3.8"
django==1.11.*; python_version<"3"
channels==1.*; python_version<"3"
django==2.*; python_version>="3"
Expand Down
40 changes: 22 additions & 18 deletions tests/test_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
pytestmark = pytest.mark.asyncio


class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
try:
from unittest.mock import AsyncMock # Python 3.8+
except ImportError:
from mock import AsyncMock


class TstServer(base_async.BaseAsyncSubscriptionServer):
Expand All @@ -26,75 +27,78 @@ def server():


async def test_terminate(server: TstServer):
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.on_connection_terminate(connection_context=context, op_id=1)
context.close.assert_called_with(1011)


async def test_send_error(server: TstServer):
context = AsyncMock()
context.has_operation = mock.Mock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.send_error(connection_context=context, op_id=1, error="test error")
context.send.assert_called_with(
{"id": 1, "type": "error", "payload": {"message": "test error"}}
)


async def test_message(server):
async def test_message(server: TstServer):
server.process_message = AsyncMock()
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
await server.on_message(context, msg)
server.process_message.assert_called_with(context, msg)


async def test_message_str(server):
async def test_message_str(server: TstServer):
server.process_message = AsyncMock()
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
await server.on_message(context, json.dumps(msg))
server.process_message.assert_called_with(context, msg)


async def test_message_invalid(server):
async def test_message_invalid(server: TstServer):
server.send_error = AsyncMock()
await server.on_message(connection_context=None, message="'not-json")
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.on_message(context, message="'not-json")
assert server.send_error.called


async def test_resolver(server):
async def test_resolver(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
result.data = {"test": [1, 2]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called


@pytest.mark.asyncio
async def test_resolver_with_promise(server):
async def test_resolver_with_promise(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called
assert result.data == {'test': [1, 2]}


async def test_resolver_with_nested_promise(server):
async def test_resolver_with_nested_promise(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
inner = promise.Promise(lambda resolve, reject: resolve(2))
outer = promise.Promise(lambda resolve, reject: resolve({'in': inner}))
result.data = {"test": [1, outer]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called
assert result.data == {'test': [1, {'in': 2}]}
45 changes: 39 additions & 6 deletions tests/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,20 @@ def test_terminate(self, ss, cc):
ss.process_message(cc, {"id": "1", "type": constants.GQL_CONNECTION_TERMINATE})
ss.on_connection_terminate.assert_called_with(cc, "1")

def test_start(self, ss, cc):
@pytest.mark.parametrize(
"transport_ws_protocol,expected_type",
((False, constants.GQL_START), (True, constants.GQL_SUBSCRIBE)),
)
def test_start(self, ss, cc, transport_ws_protocol, expected_type):
ss.get_graphql_params = mock.Mock()
ss.get_graphql_params.return_value = {"params": True}
cc.has_operation = mock.Mock()
cc.has_operation.return_value = False
cc.transport_ws_protocol = transport_ws_protocol
ss.unsubscribe = mock.Mock()
ss.on_start = mock.Mock()
ss.process_message(
cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}}
cc, {"id": "1", "type": expected_type, "payload": {"a": "b"}}
)
assert not ss.unsubscribe.called
ss.on_start.assert_called_with(cc, "1", {"params": True})
Expand Down Expand Up @@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc):
assert isinstance(ss.send_error.call_args[0][2], Exception)
assert not ss.on_start.called

def test_stop(self, ss, cc):
@pytest.mark.parametrize(
"transport_ws_protocol,stop_type,invalid_stop_type",
(
(False, constants.GQL_STOP, constants.GQL_COMPLETE),
(True, constants.GQL_COMPLETE, constants.GQL_STOP),
),
)
def test_stop(
self,
ss,
cc,
transport_ws_protocol,
stop_type,
invalid_stop_type,
):
ss.on_stop = mock.Mock()
ss.process_message(cc, {"id": "1", "type": constants.GQL_STOP})
ss.send_error = mock.Mock()
cc.transport_ws_protocol = transport_ws_protocol

ss.process_message(cc, {"id": "1", "type": invalid_stop_type})
assert ss.send_error.called
assert ss.send_error.call_args[0][:2] == (cc, "1")
assert isinstance(ss.send_error.call_args[0][2], Exception)
assert not ss.on_stop.called

ss.process_message(cc, {"id": "1", "type": stop_type})
ss.on_stop.assert_called_with(cc, "1")

def test_invalid(self, ss, cc):
Expand Down Expand Up @@ -165,13 +193,18 @@ def test_build_message_partial(ss):
ss.build_message(id=None, op_type=None, payload=None)


def test_send_execution_result(ss):
@pytest.mark.parametrize(
"transport_ws_protocol,expected_type",
((False, constants.GQL_DATA), (True, constants.GQL_NEXT)),
)
def test_send_execution_result(ss, cc, transport_ws_protocol, expected_type):
cc.transport_ws_protocol = transport_ws_protocol
ss.execution_result_to_dict = mock.Mock()
ss.execution_result_to_dict.return_value = {"res": "ult"}
ss.send_message = mock.Mock()
ss.send_message.return_value = "returned"
assert "returned" == ss.send_execution_result(cc, "1", "result")
ss.send_message.assert_called_with(cc, "1", constants.GQL_DATA, {"res": "ult"})
ss.send_message.assert_called_with(cc, "1", expected_type, {"res": "ult"})


def test_execution_result_to_dict(ss):
Expand Down