diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 31ad657..e69f12e 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -4,13 +4,17 @@ from graphql import format_error, graphql from .constants import ( + GQL_COMPLETE, GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, GQL_DATA, GQL_ERROR, + GQL_NEXT, GQL_START, GQL_STOP, + GQL_SUBSCRIBE, + TRANSPORT_WS_PROTOCOL, ) @@ -19,10 +23,15 @@ class ConnectionClosedException(Exception): class BaseConnectionContext(object): + transport_ws_protocol = False + def __init__(self, ws, request_context=None): self.ws = ws self.operations = {} self.request_context = request_context + self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in ( + request_context.get("subprotocols") or [] + ) def has_operation(self, op_id): return op_id in self.operations @@ -41,7 +50,7 @@ def remove_operation(self, op_id): def unsubscribe(self, op_id): async_iterator = self.remove_operation(op_id) - if hasattr(async_iterator, 'dispose'): + if hasattr(async_iterator, "dispose"): async_iterator.dispose() return async_iterator @@ -84,12 +93,16 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_CONNECTION_TERMINATE: return self.on_connection_terminate(connection_context, op_id) - elif op_type == GQL_START: + elif op_type == ( + GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START + ): assert isinstance(payload, dict), "The payload must be a dict" params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) - elif op_type == GQL_STOP: + elif op_type == ( + GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP + ): return self.on_stop(connection_context, op_id) else: @@ -142,7 +155,12 @@ def build_message(self, id, op_type, payload): def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) - return self.send_message(connection_context, op_id, GQL_DATA, result) + return self.send_message( + connection_context, + op_id, + GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA, + result, + ) def execution_result_to_dict(self, execution_result): result = OrderedDict() diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 8b57a60..2952296 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,5 +1,6 @@ GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS +TRANSPORT_WS_PROTOCOL = "graphql-transport-ws" GQL_CONNECTION_INIT = "connection_init" # Client -> Server GQL_CONNECTION_ACK = "connection_ack" # Server -> Client @@ -8,8 +9,11 @@ # NOTE: This one here don't follow the standard due to connection optimization GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client -GQL_START = "start" # Client -> Server -GQL_DATA = "data" # Server -> Client +GQL_START = "start" # Client -> Server (graphql-ws) +GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent) +GQL_DATA = "data" # Server -> Client (graphql-ws) +GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent) GQL_ERROR = "error" # Server -> Client GQL_COMPLETE = "complete" # Server -> Client -GQL_STOP = "stop" # Client -> Server +# (and Client -> Server for graphql-transport-ws STOP equivalent) +GQL_STOP = "stop" # Client -> Server (graphql-ws only) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index b1c64d1..1dc10ab 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -2,21 +2,25 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer -from ..constants import WS_PROTOCOL +from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL from .subscriptions import subscription_server class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): - async def connect(self): self.connection_context = None - if WS_PROTOCOL in self.scope["subprotocols"]: - self.connection_context = await subscription_server.handle( - ws=self, request_context=self.scope - ) - await self.accept(subprotocol=WS_PROTOCOL) - else: + found_protocol = None + for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]: + if protocol in self.scope["subprotocols"]: + found_protocol = protocol + break + if not found_protocol: await self.close() + return + self.connection_context = await subscription_server.handle( + ws=self, request_context=self.scope + ) + await self.accept(subprotocol=found_protocol) async def disconnect(self, code): if self.connection_context: diff --git a/setup.cfg b/setup.cfg index 3d07a80..ded02e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,9 +50,8 @@ test = pytest-asyncio; python_version>="3.4" graphene>=2.0,<3 gevent - graphene>=2.0 graphene_django - mock; python_version<"3" + mock; python_version<"3.8" django==1.11.*; python_version<"3" channels==1.*; python_version<"3" django==2.*; python_version>="3" diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d62eda5..50c4309 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -10,9 +10,10 @@ pytestmark = pytest.mark.asyncio -class AsyncMock(mock.MagicMock): - async def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) +try: + from unittest.mock import AsyncMock # Python 3.8+ +except ImportError: + from mock import AsyncMock class TstServer(base_async.BaseAsyncSubscriptionServer): @@ -26,75 +27,78 @@ def server(): async def test_terminate(server: TstServer): - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) async def test_send_error(server: TstServer): - context = AsyncMock() - context.has_operation = mock.Mock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} ) -async def test_message(server): +async def test_message(server: TstServer): server.process_message = AsyncMock() - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} await server.on_message(context, msg) server.process_message.assert_called_with(context, msg) -async def test_message_str(server): +async def test_message_str(server: TstServer): server.process_message = AsyncMock() - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} await server.on_message(context, json.dumps(msg)) server.process_message.assert_called_with(context, msg) -async def test_message_invalid(server): +async def test_message_invalid(server: TstServer): server.send_error = AsyncMock() - await server.on_message(connection_context=None, message="'not-json") + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) + await server.on_message(context, message="'not-json") assert server.send_error.called -async def test_resolver(server): +async def test_resolver(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() result.data = {"test": [1, 2]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called @pytest.mark.asyncio -async def test_resolver_with_promise(server): +async def test_resolver_with_promise(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called assert result.data == {'test': [1, 2]} -async def test_resolver_with_nested_promise(server): +async def test_resolver_with_nested_promise(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() inner = promise.Promise(lambda resolve, reject: resolve(2)) outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) result.data = {"test": [1, outer]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called assert result.data == {'test': [1, {'in': 2}]} diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3b85c49..bb30d4a 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -77,15 +77,20 @@ def test_terminate(self, ss, cc): ss.process_message(cc, {"id": "1", "type": constants.GQL_CONNECTION_TERMINATE}) ss.on_connection_terminate.assert_called_with(cc, "1") - def test_start(self, ss, cc): + @pytest.mark.parametrize( + "transport_ws_protocol,expected_type", + ((False, constants.GQL_START), (True, constants.GQL_SUBSCRIBE)), + ) + def test_start(self, ss, cc, transport_ws_protocol, expected_type): ss.get_graphql_params = mock.Mock() ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = False + cc.transport_ws_protocol = transport_ws_protocol ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} + cc, {"id": "1", "type": expected_type, "payload": {"a": "b"}} ) assert not ss.unsubscribe.called ss.on_start.assert_called_with(cc, "1", {"params": True}) @@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc): assert isinstance(ss.send_error.call_args[0][2], Exception) assert not ss.on_start.called - def test_stop(self, ss, cc): + @pytest.mark.parametrize( + "transport_ws_protocol,stop_type,invalid_stop_type", + ( + (False, constants.GQL_STOP, constants.GQL_COMPLETE), + (True, constants.GQL_COMPLETE, constants.GQL_STOP), + ), + ) + def test_stop( + self, + ss, + cc, + transport_ws_protocol, + stop_type, + invalid_stop_type, + ): ss.on_stop = mock.Mock() - ss.process_message(cc, {"id": "1", "type": constants.GQL_STOP}) + ss.send_error = mock.Mock() + cc.transport_ws_protocol = transport_ws_protocol + + ss.process_message(cc, {"id": "1", "type": invalid_stop_type}) + assert ss.send_error.called + assert ss.send_error.call_args[0][:2] == (cc, "1") + assert isinstance(ss.send_error.call_args[0][2], Exception) + assert not ss.on_stop.called + + ss.process_message(cc, {"id": "1", "type": stop_type}) ss.on_stop.assert_called_with(cc, "1") def test_invalid(self, ss, cc): @@ -165,13 +193,18 @@ def test_build_message_partial(ss): ss.build_message(id=None, op_type=None, payload=None) -def test_send_execution_result(ss): +@pytest.mark.parametrize( + "transport_ws_protocol,expected_type", + ((False, constants.GQL_DATA), (True, constants.GQL_NEXT)), +) +def test_send_execution_result(ss, cc, transport_ws_protocol, expected_type): + cc.transport_ws_protocol = transport_ws_protocol ss.execution_result_to_dict = mock.Mock() ss.execution_result_to_dict.return_value = {"res": "ult"} ss.send_message = mock.Mock() ss.send_message.return_value = "returned" assert "returned" == ss.send_execution_result(cc, "1", "result") - ss.send_message.assert_called_with(cc, "1", constants.GQL_DATA, {"res": "ult"}) + ss.send_message.assert_called_with(cc, "1", expected_type, {"res": "ult"}) def test_execution_result_to_dict(ss):