From c00066babc8de09af521e3ffcbf360c74808e6b4 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 18 May 2020 02:54:17 +1200 Subject: [PATCH 01/42] Split base classes into sync and async Deduplicate code --- graphql_ws/__init__.py | 3 - graphql_ws/aiohttp.py | 93 +++----------------- graphql_ws/base.py | 150 ++++++++++++++------------------- graphql_ws/base_async.py | 118 ++++++++++++++++++++++++++ graphql_ws/base_sync.py | 88 +++++++++++++++++++ graphql_ws/django_channels.py | 97 ++------------------- graphql_ws/gevent.py | 79 ++--------------- graphql_ws/observable_aiter.py | 47 +---------- graphql_ws/websockets_lib.py | 93 +++----------------- 9 files changed, 312 insertions(+), 456 deletions(-) create mode 100644 graphql_ws/base_async.py create mode 100644 graphql_ws/base_sync.py diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 44c7dc3..793831a 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -5,6 +5,3 @@ __author__ = """Syrus Akbary""" __email__ = 'me@syrusakbary.com' __version__ = '0.3.1' - - -from .base import BaseConnectionContext, BaseSubscriptionServer # noqa: F401 diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 363ca67..49e0a5e 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,23 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield +import json +from asyncio import ensure_future, shield from aiohttp import WSMsgType -from graphql.execution.executors.asyncio import AsyncioExecutor -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -from .constants import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_COMPLETE -) -setup_observable_extension() - - -class AiohttpConnectionContext(BaseConnectionContext): +class AiohttpConnectionContext(BaseAsyncConnectionContext): async def receive(self): msg = await self.ws.receive() if msg.type == WSMsgType.TEXT: @@ -32,7 +22,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send_str(data) + await self.ws.send_str(json.dumps(data)) @property def closed(self): @@ -42,21 +32,10 @@ async def close(self, code): await self.ws.close(code=code) -class AiohttpSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(AiohttpSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class AiohttpSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -64,59 +43,13 @@ async def _handle(self, ws, request_context=None): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - - self.on_close(connection_context) - for task in pending: - task.cancel() + connection_context.remember_task( + ensure_future( + self.on_message(connection_context, message), loop=self.loop + ) + ) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index f3aa1e7..d146419 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,16 +1,16 @@ import json from collections import OrderedDict -from graphql import graphql, format_error +from graphql import format_error from .constants import ( + GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, + GQL_DATA, + GQL_ERROR, GQL_START, GQL_STOP, - GQL_ERROR, - GQL_CONNECTION_ERROR, - GQL_DATA ) @@ -51,33 +51,16 @@ def close(self, code): class BaseSubscriptionServer(object): + graphql_executor = None def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive - def get_graphql_params(self, connection_context, payload): - return { - 'request_string': payload.get('query'), - 'variable_values': payload.get('variables'), - 'operation_name': payload.get('operationName'), - 'context_value': payload.get('context'), - } - - def build_message(self, id, op_type, payload): - message = {} - if id is not None: - message['id'] = id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - return message - def process_message(self, connection_context, parsed_message): - op_id = parsed_message.get('id') - op_type = parsed_message.get('type') - payload = parsed_message.get('payload') + op_id = parsed_message.get("id") + op_type = parsed_message.get("type") + payload = parsed_message.get("payload") if op_type == GQL_CONNECTION_INIT: return self.on_connection_init(connection_context, op_id, payload) @@ -92,7 +75,8 @@ def process_message(self, connection_context, parsed_message): if not isinstance(params, dict): error = Exception( "Invalid params returned from get_graphql_params!" - " Return values must be a dict.") + " Return values must be a dict." + ) return self.send_error(connection_context, op_id, error) # If we already have a subscription with this id, unsubscribe from @@ -100,14 +84,54 @@ def process_message(self, connection_context, parsed_message): if connection_context.has_operation(op_id): self.unsubscribe(connection_context, op_id) + params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) elif op_type == GQL_STOP: return self.on_stop(connection_context, op_id) else: - return self.send_error(connection_context, op_id, Exception( - "Invalid message type: {}.".format(op_type))) + return self.send_error( + connection_context, + op_id, + Exception("Invalid message type: {}.".format(op_type)), + ) + + def on_connection_init(self, connection_context, op_id, payload): + raise NotImplementedError("on_connection_init method not implemented") + + def on_connection_terminate(self, connection_context, op_id): + return connection_context.close(1011) + + def get_graphql_params(self, connection_context, payload): + return { + "request_string": payload.get("query"), + "variable_values": payload.get("variables"), + "operation_name": payload.get("operationName"), + "context_value": payload.get("context"), + "executor": self.graphql_executor(), + } + + def on_open(self, connection_context): + raise NotImplementedError("on_open method not implemented") + + def on_stop(self, connection_context, op_id): + raise NotImplementedError("on_stop method not implemented") + + def send_message(self, connection_context, op_id=None, op_type=None, payload=None): + message = self.build_message(op_id, op_type, payload) + assert message, "You need to send at least one thing" + return connection_context.send(message) + + def build_message(self, id, op_type, payload): + message = {} + if id is not None: + message["id"] = id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + return message def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) @@ -116,86 +140,34 @@ def send_execution_result(self, connection_context, op_id, execution_result): def execution_result_to_dict(self, execution_result): result = OrderedDict() if execution_result.data: - result['data'] = execution_result.data + result["data"] = execution_result.data if execution_result.errors: - result['errors'] = [format_error(error) - for error in execution_result.errors] + result["errors"] = [ + format_error(error) for error in execution_result.errors + ] return result - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - json_message = json.dumps(message) - return connection_context.send(json_message) - def send_error(self, connection_context, op_id, error, error_type=None): if error_type is None: error_type = GQL_ERROR assert error_type in [GQL_CONNECTION_ERROR, GQL_ERROR], ( - 'error_type should be one of the allowed error messages' - ' GQL_CONNECTION_ERROR or GQL_ERROR' - ) - - error_payload = { - 'message': str(error) - } - - return self.send_message( - connection_context, - op_id, - error_type, - error_payload + "error_type should be one of the allowed error messages" + " GQL_CONNECTION_ERROR or GQL_ERROR" ) - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + error_payload = {"message": str(error)} - def on_operation_complete(self, connection_context, op_id): - pass - - def on_connection_terminate(self, connection_context, op_id): - return connection_context.close(1011) - - def execute(self, request_context, params): - return graphql( - self.schema, **dict(params, allow_subscriptions=True)) - - def handle(self, ws, request_context=None): - raise NotImplementedError("handle method not implemented") + return self.send_message(connection_context, op_id, error_type, error_payload) def on_message(self, connection_context, message): try: if not isinstance(message, dict): parsed_message = json.loads(message) - assert isinstance( - parsed_message, dict), "Payload must be an object." + assert isinstance(parsed_message, dict), "Payload must be an object." else: parsed_message = message except Exception as e: return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def on_open(self, connection_context): - raise NotImplementedError("on_open method not implemented") - - def on_connect(self, connection_context, payload): - raise NotImplementedError("on_connect method not implemented") - - def on_close(self, connection_context): - raise NotImplementedError("on_close method not implemented") - - def on_connection_init(self, connection_context, op_id, payload): - raise NotImplementedError("on_connection_init method not implemented") - - def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") - - def on_start(self, connection_context, op_id, params): - raise NotImplementedError("on_start method not implemented") diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py new file mode 100644 index 0000000..d067a8d --- /dev/null +++ b/graphql_ws/base_async.py @@ -0,0 +1,118 @@ +import asyncio +from abc import ABC, abstractmethod +from inspect import isawaitable +from weakref import WeakSet + +from graphql.execution.executors.asyncio import AsyncioExecutor + +from graphql_ws import base + +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .observable_aiter import setup_observable_extension + +setup_observable_extension() + + +class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): + def __init__(self, ws, request_context=None): + super().__init__(ws, request_context=request_context) + self.pending_tasks = WeakSet() + + @abstractmethod + async def receive(self): + raise NotImplementedError("receive method not implemented") + + @abstractmethod + async def send(self, data): + ... + + @property + @abstractmethod + def closed(self): + ... + + @abstractmethod + async def close(self, code): + ... + + def remember_task(self, task): + self.pending_tasks.add(asyncio.ensure_future(task)) + # Clear completed tasks + self.pending_tasks -= WeakSet( + task for task in self.pending_tasks if task.done() + ) + + +class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): + graphql_executor = AsyncioExecutor + + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + @abstractmethod + async def handle(self, ws, request_context=None): + ... + + def process_message(self, connection_context, parsed_message): + task = asyncio.ensure_future( + super().process_message(connection_context, parsed_message) + ) + connection_context.pending.add(task) + return task + + async def send_message(self, *args, **kwargs): + await super().send_message(*args, **kwargs) + + async def on_open(self, connection_context): + pass + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute(connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if hasattr(execution_result, "__aiter__"): + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + else: + await self.send_execution_result( + connection_context, op_id, execution_result + ) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + await self.on_operation_complete(connection_context, op_id) + + async def on_close(self, connection_context): + awaitables = tuple( + self.unsubscribe(connection_context, op_id) + for op_id in connection_context.operations + ) + tuple(task.cancel() for task in connection_context.pending_tasks) + if awaitables: + await asyncio.gather(*awaitables, loop=self.loop) + + async def on_stop(self, connection_context, op_id): + await self.unsubscribe(connection_context, op_id) + + async def unsubscribe(self, connection_context, op_id): + await super().unsubscribe(connection_context, op_id) + + async def on_operation_complete(self, connection_context, op_id): + pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py new file mode 100644 index 0000000..b7cb412 --- /dev/null +++ b/graphql_ws/base_sync.py @@ -0,0 +1,88 @@ +from graphql import graphql +from graphql.execution.executors.sync import SyncExecutor +from rx import Observable, Observer + +from .base import BaseSubscriptionServer +from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + + +class BaseSyncSubscriptionServer(BaseSubscriptionServer): + graphql_executor = SyncExecutor + + def unsubscribe(self, connection_context, op_id): + if connection_context.has_operation(op_id): + # Close async iterator + connection_context.get_operation(op_id).dispose() + # Close operation + connection_context.remove_operation(op_id) + self.on_operation_complete(connection_context, op_id) + + def on_operation_complete(self, connection_context, op_id): + pass + + def execute(self, request_context, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) + + def handle(self, ws, request_context=None): + raise NotImplementedError("handle method not implemented") + + def on_open(self, connection_context): + pass + + def on_connect(self, connection_context, payload): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + def on_connection_init(self, connection_context, op_id, payload): + try: + self.on_connect(connection_context, payload) + self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + + except Exception as e: + self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + connection_context.close(1011) + + def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id) + + def on_start(self, connection_context, op_id, params): + try: + execution_result = self.execute(connection_context.request_context, params) + assert isinstance( + execution_result, Observable + ), "A subscription must return an observable" + execution_result.subscribe( + SubscriptionObserver( + connection_context, + op_id, + self.send_execution_result, + self.send_error, + self.on_close, + ) + ) + except Exception as e: + self.send_error(connection_context, op_id, str(e)) + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, on_close + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.on_close = on_close + + def on_next(self, value): + self.send_execution_result(self.connection_context, self.op_id, value) + + def on_completed(self): + self.on_close(self.connection_context) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index 61a7247..fbee47b 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -1,94 +1,30 @@ import json -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor - from channels.generic.websockets import JsonWebsocketConsumer from graphene_django.settings import graphene_settings -from .base import BaseConnectionContext, BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .base import BaseConnectionContext +from .base_sync import BaseSyncSubscriptionServer class DjangoChannelConnectionContext(BaseConnectionContext): - def __init__(self, message, request_context=None): self.message = message self.operations = {} self.request_context = request_context def send(self, data): - self.message.reply_channel.send(data) + self.message.reply_channel.send({"text": json.dumps(data)}) def close(self, reason): - data = { - 'close': True, - 'text': reason - } + data = {"close": True, "text": reason} self.message.reply_channel.send(data) -class DjangoChannelSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(DjangoChannelSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, message, connection_context): self.on_message(connection_context, message) - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = {} - if op_id is not None: - message['id'] = op_id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - - assert message, "You need to send at least one thing" - return connection_context.send({'text': json.dumps(message)}) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True @@ -104,26 +40,7 @@ def receive(self, content, **_kwargs): """ self.connection_context = DjangoChannelConnectionContext(self.message) self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA) + graphene_settings.SCHEMA + ) self.subscription_server.on_open(self.connection_context) self.subscription_server.handle(content, self.connection_context) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index aadbe64..b7d6849 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -1,15 +1,15 @@ from __future__ import absolute_import -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor +import json from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + BaseConnectionContext, + ConnectionClosedException, +) +from .base_sync import BaseSyncSubscriptionServer class GeventConnectionContext(BaseConnectionContext): - def receive(self): msg = self.ws.receive() return msg @@ -17,7 +17,7 @@ def receive(self): def send(self, data): if self.closed: return - self.ws.send(data) + self.ws.send(json.dumps(data)) @property def closed(self): @@ -27,13 +27,7 @@ def close(self, code): self.ws.close(code) -class GeventSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(GeventSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class GeventSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, ws, request_context=None): connection_context = GeventConnectionContext(ws, request_context) self.on_open(connection_context) @@ -46,62 +40,3 @@ def handle(self, ws, request_context=None): self.on_close(connection_context) return self.on_message(connection_context, message) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/observable_aiter.py b/graphql_ws/observable_aiter.py index 0bd1a59..424d95f 100644 --- a/graphql_ws/observable_aiter.py +++ b/graphql_ws/observable_aiter.py @@ -1,7 +1,7 @@ from asyncio import Future -from rx.internal import extensionmethod from rx.core import Observable +from rx.internal import extensionmethod async def __aiter__(self): @@ -13,15 +13,11 @@ def __init__(self): self.future = Future() self.disposable = source.materialize().subscribe(self.on_next) - # self.disposed = False def __aiter__(self): return self def dispose(self): - # self.future.cancel() - # self.disposed = True - # self.future.set_exception(StopAsyncIteration) self.disposable.dispose() def feeder(self): @@ -30,11 +26,11 @@ def feeder(self): notification = self.notifications.pop(0) kind = notification.kind - if kind == 'N': + if kind == "N": self.future.set_result(notification.value) - if kind == 'E': + if kind == "E": self.future.set_exception(notification.exception) - if kind == 'C': + if kind == "C": self.future.set_exception(StopAsyncIteration) def on_next(self, notification): @@ -42,8 +38,6 @@ def on_next(self, notification): self.feeder() async def __anext__(self): - # if self.disposed: - # raise StopAsyncIteration self.feeder() value = await self.future @@ -53,38 +47,5 @@ async def __anext__(self): return AIterator() -# def __aiter__(self, sentinel=None): -# loop = get_event_loop() -# future = [Future()] -# notifications = [] - -# def feeder(): -# if not len(notifications) or future[0].done(): -# return -# notification = notifications.pop(0) -# if notification.kind == "E": -# future[0].set_exception(notification.exception) -# elif notification.kind == "C": -# future[0].set_exception(StopIteration(sentinel)) -# else: -# future[0].set_result(notification.value) - -# def on_next(value): -# """Takes on_next values and appends them to the notification queue""" -# notifications.append(value) -# loop.call_soon(feeder) - -# self.materialize().subscribe(on_next) - -# @asyncio.coroutine -# def gen(): -# """Generator producing futures""" -# loop.call_soon(feeder) -# future[0] = Future() -# return future[0] - -# return gen - - def setup_observable_extension(): extensionmethod(Observable)(__aiter__) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 7e78d5d..93ad76f 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,19 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield -from websockets import ConnectionClosed -from graphql.execution.executors.asyncio import AsyncioExecutor - -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +import json +from asyncio import ensure_future, shield -from .constants import ( - GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE) +from websockets import ConnectionClosed -setup_observable_extension() +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -class WsLibConnectionContext(BaseConnectionContext): +class WsLibConnectionContext(BaseAsyncConnectionContext): async def receive(self): try: msg = await self.ws.recv() @@ -24,7 +18,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send(data) + await self.ws.send(json.dumps(data)) @property def closed(self): @@ -34,21 +28,10 @@ async def close(self, code): await self.ws.close(code) -class WsLibSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(WsLibSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class WsLibSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context): connection_context = WsLibConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -56,61 +39,13 @@ async def _handle(self, ws, request_context): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - self.on_close(connection_context) - for task in pending: - task.cancel() + connection_context.remember_task( + ensure_future( + self.on_message(connection_context, message), loop=self.loop + ) + ) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error( - connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) From b6951ed404aa7bf4f8c79e0c4a1c20a3aca435f9 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 18 May 2020 02:54:24 +1200 Subject: [PATCH 02/42] Fix tests to match deduplication changes --- tests/test_aiohttp.py | 2 +- tests/test_django_channels.py | 2 +- tests/test_gevent.py | 4 ++-- tests/test_graphql_ws.py | 35 +++++++---------------------------- 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f20ca15..88a48d1 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -55,7 +55,7 @@ async def test_receive_closed(self, mock_ws): async def test_send(self, mock_ws): connection_context = AiohttpConnectionContext(ws=mock_ws) await connection_context.send("test") - mock_ws.send_str.assert_called_with("test") + mock_ws.send_str.assert_called_with('"test"') async def test_send_closed(self, mock_ws): mock_ws.closed = True diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index e7b054c..51ef6ae 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -14,7 +14,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with("test") + msg.reply_channel.send.assert_called_with({'text': '"test"'}) def test_close(self): msg = mock.Mock() diff --git a/tests/test_gevent.py b/tests/test_gevent.py index f766c5a..a734970 100644 --- a/tests/test_gevent.py +++ b/tests/test_gevent.py @@ -17,8 +17,8 @@ def test_send(self): ws = mock.Mock() ws.closed = False connection_context = GeventConnectionContext(ws=ws) - connection_context.send("test") - ws.send.assert_called_with("test") + connection_context.send({"text": "test"}) + ws.send.assert_called_with('{"text": "test"}') def test_send_closed(self): ws = mock.Mock() diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3ba1120..65cbf91 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -5,8 +5,9 @@ import mock import pytest +from graphql.execution.executors.sync import SyncExecutor -from graphql_ws import base, constants +from graphql_ws import base, base_sync, constants @pytest.fixture @@ -18,7 +19,7 @@ def cc(): @pytest.fixture def ss(): - return base.BaseSubscriptionServer(schema=None) + return base_sync.BaseSyncSubscriptionServer(schema=None) class TestConnectionContextOperation: @@ -137,7 +138,9 @@ def test_get_graphql_params(ss, cc): "operationName": "query", "context": "ctx", } - assert ss.get_graphql_params(cc, payload) == { + params = ss.get_graphql_params(cc, payload) + assert isinstance(params.pop("executor"), SyncExecutor) + assert params == { "request_string": "req", "variable_values": "vars", "operation_name": "query", @@ -189,34 +192,10 @@ def test_send_message(ss, cc): cc.send = mock.Mock() cc.send.return_value = "returned" assert "returned" == ss.send_message(cc) - cc.send.assert_called_with('{"mess": "age"}') + cc.send.assert_called_with({"mess": "age"}) class TestSSNotImplemented: def test_handle(self, ss): with pytest.raises(NotImplementedError): ss.handle(ws=None, request_context=None) - - def test_on_open(self, ss): - with pytest.raises(NotImplementedError): - ss.on_open(connection_context=None) - - def test_on_connect(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connect(connection_context=None, payload=None) - - def test_on_close(self, ss): - with pytest.raises(NotImplementedError): - ss.on_close(connection_context=None) - - def test_on_connection_init(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connection_init(connection_context=None, op_id=None, payload=None) - - def test_on_stop(self, ss): - with pytest.raises(NotImplementedError): - ss.on_stop(connection_context=None, op_id=None) - - def test_on_start(self, ss): - with pytest.raises(NotImplementedError): - ss.on_start(connection_context=None, op_id=None, params=None) From 4fe4736896b88e7ee3652941a32d3becbeed455c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 1 Jun 2020 23:48:10 +1200 Subject: [PATCH 03/42] Add some base tests --- setup.cfg | 5 ++++ tests/test_base.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++ tox.ini | 1 + 3 files changed, 68 insertions(+) create mode 100644 tests/test_base.py diff --git a/setup.cfg b/setup.cfg index df50b23..b921bca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,3 +90,8 @@ ignore = W503 [coverage:run] omit = .tox/* + +[coverage:report] +exclude_lines = + pragma: no cover + @abstract \ No newline at end of file diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..2e78459 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,62 @@ +try: + from unittest import mock +except ImportError: + import mock + +import json + +import pytest + +from graphql_ws import base + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + with pytest.raises(NotImplementedError): + server.on_stop(connection_context=None, op_id=1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called diff --git a/tox.ini b/tox.ini index 6de6deb..42d13b4 100644 --- a/tox.ini +++ b/tox.ini @@ -31,5 +31,6 @@ skip_install = true deps = coverage commands = coverage html + coverage xml coverage report --include="tests/*" --fail-under=100 -m coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file From f7ef09fef63b9a5408496988fe5292a78d6a5475 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 6 Jun 2020 17:08:24 +1200 Subject: [PATCH 04/42] Add base async tests --- tests/conftest.py | 4 +-- tests/test_base_async.py | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 tests/test_base_async.py diff --git a/tests/conftest.py b/tests/conftest.py index e551557..fa905b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,6 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] if sys.version_info < (3, 6): - collect_ignore.append('test_gevent.py') + collect_ignore.append("test_gevent.py") else: - collect_ignore = ["test_aiohttp.py"] + collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tests/test_base_async.py b/tests/test_base_async.py new file mode 100644 index 0000000..902acc7 --- /dev/null +++ b/tests/test_base_async.py @@ -0,0 +1,59 @@ +from unittest import mock + +import json + +import pytest + +from graphql_ws import base + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + with pytest.raises(NotImplementedError): + server.on_stop(connection_context=None, op_id=1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called From 944d94982abbe6588273b3cca9c186a701149f33 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 6 Jun 2020 17:08:34 +1200 Subject: [PATCH 05/42] Add django_channels tests --- tests/django_routing.py | 6 +++++ tests/test_django_channels.py | 46 +++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/django_routing.py diff --git a/tests/django_routing.py b/tests/django_routing.py new file mode 100644 index 0000000..9d01766 --- /dev/null +++ b/tests/django_routing.py @@ -0,0 +1,6 @@ +from channels.routing import route +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + +channel_routing = [ + route("websocket.receive", GraphQLSubscriptionConsumer), +] diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index 51ef6ae..137d541 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -1,11 +1,35 @@ +from __future__ import unicode_literals + +import json + +import django import mock +from channels import Channel +from channels.test import ChannelTestCase, Client from django.conf import settings +from django.core.management import call_command -settings.configure() # noqa +settings.configure( + CHANNEL_LAYERS={ + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "tests.django_routing.channel_routing", + }, + }, + INSTALLED_APPS=[ + "django.contrib.sessions", + "django.contrib.contenttypes", + "django.contrib.auth", + ], + DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}}, +) +django.setup() +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT from graphql_ws.django_channels import ( DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, + GraphQLSubscriptionConsumer, ) @@ -14,7 +38,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with({'text': '"test"'}) + msg.reply_channel.send.assert_called_with({"text": '"test"'}) def test_close(self): msg = mock.Mock() @@ -25,3 +49,21 @@ def test_close(self): def test_subscription_server_smoke(): DjangoChannelSubscriptionServer(schema=None) + + +class TestConsumer(ChannelTestCase): + def test_connect(self): + call_command("migrate") + Channel("websocket.receive").send( + { + "path": "/graphql", + "order": 0, + "reply_channel": "websocket.receive", + "text": json.dumps({"type": GQL_CONNECTION_INIT, "id": 1}), + } + ) + message = self.get_next_message("websocket.receive", require=True) + GraphQLSubscriptionConsumer(message) + result = self.get_next_message("websocket.receive", require=True) + result_content = json.loads(result.content["text"]) + assert result_content == {"type": GQL_CONNECTION_ACK} From e4b3d9f9b4c5c0d94a8717ac4694d454bf31aeb3 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 19 May 2020 15:39:11 +1200 Subject: [PATCH 06/42] Remove a redundant check for an internal detail It'll still cause an exception on .execute() if somehow a third party subscription server did the wrong thing anyway --- graphql_ws/base.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index d146419..ee82dec 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -71,14 +71,6 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - params = self.get_graphql_params(connection_context, payload) - if not isinstance(params, dict): - error = Exception( - "Invalid params returned from get_graphql_params!" - " Return values must be a dict." - ) - return self.send_error(connection_context, op_id, error) - # If we already have a subscription with this id, unsubscribe from # it first if connection_context.has_operation(op_id): From 7b21f0fe5235c548dca27ce7a718dbab8f85a1ab Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 19 May 2020 15:52:49 +1200 Subject: [PATCH 07/42] Move execute back to base --- graphql_ws/base.py | 9 +++++++-- graphql_ws/base_async.py | 2 +- graphql_ws/base_sync.py | 6 +----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index ee82dec..30bd766 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,7 +1,7 @@ import json from collections import OrderedDict -from graphql import format_error +from graphql import format_error, graphql from .constants import ( GQL_CONNECTION_ERROR, @@ -57,6 +57,9 @@ def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive + def execute(self, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) + def process_message(self, connection_context, parsed_message): op_id = parsed_message.get("id") op_type = parsed_message.get("type") @@ -96,11 +99,13 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): + context = payload.get('context') or {} + context.setdefault('request_context', connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), "operation_name": payload.get("operationName"), - "context_value": payload.get("context"), + "context_value": context, "executor": self.graphql_executor(), } diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index d067a8d..3252196 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -79,7 +79,7 @@ async def on_connection_init(self, connection_context, op_id, payload): await connection_context.close(1011) async def on_start(self, connection_context, op_id, params): - execution_result = self.execute(connection_context.request_context, params) + execution_result = self.execute(params) if isawaitable(execution_result): execution_result = await execution_result diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index b7cb412..70bdbfc 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -1,4 +1,3 @@ -from graphql import graphql from graphql.execution.executors.sync import SyncExecutor from rx import Observable, Observer @@ -20,9 +19,6 @@ def unsubscribe(self, connection_context, op_id): def on_operation_complete(self, connection_context, op_id): pass - def execute(self, request_context, params): - return graphql(self.schema, **dict(params, allow_subscriptions=True)) - def handle(self, ws, request_context=None): raise NotImplementedError("handle method not implemented") @@ -51,7 +47,7 @@ def on_stop(self, connection_context, op_id): def on_start(self, connection_context, op_id, params): try: - execution_result = self.execute(connection_context.request_context, params) + execution_result = self.execute(params) assert isinstance( execution_result, Observable ), "A subscription must return an observable" From 75cad357b9b29ac247efda975bfd2f8a167e7c31 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:46:47 +1200 Subject: [PATCH 08/42] Move operation unsubscription to BaseSubscriptionServer --- graphql_ws/base.py | 8 ++++++++ graphql_ws/base_async.py | 3 --- graphql_ws/base_sync.py | 12 +++--------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 30bd766..9fb931d 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -168,3 +168,11 @@ def on_message(self, connection_context, message): return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) + + def unsubscribe(self, connection_context, op_id): + if connection_context.has_operation(op_id): + # Close async iterator + connection_context.get_operation(op_id).dispose() + # Close operation + connection_context.remove_operation(op_id) + self.on_operation_complete(connection_context, op_id) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 3252196..95d2f2b 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -111,8 +111,5 @@ async def on_close(self, connection_context): async def on_stop(self, connection_context, op_id): await self.unsubscribe(connection_context, op_id) - async def unsubscribe(self, connection_context, op_id): - await super().unsubscribe(connection_context, op_id) - async def on_operation_complete(self, connection_context, op_id): pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 70bdbfc..56b4d42 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -8,14 +8,6 @@ class BaseSyncSubscriptionServer(BaseSubscriptionServer): graphql_executor = SyncExecutor - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) - def on_operation_complete(self, connection_context, op_id): pass @@ -51,7 +43,7 @@ def on_start(self, connection_context, op_id, params): assert isinstance( execution_result, Observable ), "A subscription must return an observable" - execution_result.subscribe( + disposable = execution_result.subscribe( SubscriptionObserver( connection_context, op_id, @@ -60,6 +52,8 @@ def on_start(self, connection_context, op_id, params): self.on_close, ) ) + connection_context.register_operation(op_id, disposable) + except Exception as e: self.send_error(connection_context, op_id, str(e)) From e904b1513b8ca47b91c7d03ea92fc5fdf85c99a7 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:59:07 +1200 Subject: [PATCH 09/42] Black the example code --- examples/aiohttp/app.py | 21 +++-- examples/aiohttp/schema.py | 4 +- examples/aiohttp/template.py | 15 +-- .../django_subscriptions/asgi.py | 2 +- .../django_subscriptions/schema.py | 15 +-- .../django_subscriptions/settings.py | 91 +++++++++---------- .../django_subscriptions/template.py | 15 +-- .../django_subscriptions/urls.py | 15 +-- examples/flask_gevent/app.py | 14 +-- examples/flask_gevent/schema.py | 12 ++- examples/flask_gevent/template.py | 15 +-- examples/websockets_lib/app.py | 14 +-- examples/websockets_lib/schema.py | 4 +- examples/websockets_lib/template.py | 15 +-- 14 files changed, 128 insertions(+), 124 deletions(-) diff --git a/examples/aiohttp/app.py b/examples/aiohttp/app.py index 56dcaff..336a0c6 100644 --- a/examples/aiohttp/app.py +++ b/examples/aiohttp/app.py @@ -10,24 +10,25 @@ async def graphql_view(request): payload = await request.json() - response = await schema.execute(payload.get('query', ''), return_promise=True) + response = await schema.execute(payload.get("query", ""), return_promise=True) data = {} if response.errors: - data['errors'] = [format_error(e) for e in response.errors] + data["errors"] = [format_error(e) for e in response.errors] if response.data: - data['data'] = response.data + data["data"] = response.data jsondata = json.dumps(data,) - return web.Response(text=jsondata, headers={'Content-Type': 'application/json'}) + return web.Response(text=jsondata, headers={"Content-Type": "application/json"}) async def graphiql_view(request): - return web.Response(text=render_graphiql(), headers={'Content-Type': 'text/html'}) + return web.Response(text=render_graphiql(), headers={"Content-Type": "text/html"}) + subscription_server = AiohttpSubscriptionServer(schema) async def subscriptions(request): - ws = web.WebSocketResponse(protocols=('graphql-ws',)) + ws = web.WebSocketResponse(protocols=("graphql-ws",)) await ws.prepare(request) await subscription_server.handle(ws) @@ -35,9 +36,9 @@ async def subscriptions(request): app = web.Application() -app.router.add_get('/subscriptions', subscriptions) -app.router.add_get('/graphiql', graphiql_view) -app.router.add_get('/graphql', graphql_view) -app.router.add_post('/graphql', graphql_view) +app.router.add_get("/subscriptions", subscriptions) +app.router.add_get("/graphiql", graphiql_view) +app.router.add_get("/graphql", graphql_view) +app.router.add_post("/graphql", graphql_view) web.run_app(app, port=8000) diff --git a/examples/aiohttp/schema.py b/examples/aiohttp/schema.py index 3c23d00..ae107c7 100644 --- a/examples/aiohttp/schema.py +++ b/examples/aiohttp/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/aiohttp/template.py b/examples/aiohttp/template.py index 0b74e96..709f7cf 100644 --- a/examples/aiohttp/template.py +++ b/examples/aiohttp/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/asgi.py b/examples/django_subscriptions/django_subscriptions/asgi.py index e6edd7d..35b4d4d 100644 --- a/examples/django_subscriptions/django_subscriptions/asgi.py +++ b/examples/django_subscriptions/django_subscriptions/asgi.py @@ -3,4 +3,4 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_subscriptions.settings") -channel_layer = get_channel_layer() \ No newline at end of file +channel_layer = get_channel_layer() diff --git a/examples/django_subscriptions/django_subscriptions/schema.py b/examples/django_subscriptions/django_subscriptions/schema.py index b55d76e..db6893c 100644 --- a/examples/django_subscriptions/django_subscriptions/schema.py +++ b/examples/django_subscriptions/django_subscriptions/schema.py @@ -6,18 +6,19 @@ class Query(graphene.ObjectType): hello = graphene.String() def resolve_hello(self, info, **kwargs): - return 'world' + return "world" + class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) - def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) -schema = graphene.Schema(query=Query, subscription=Subscription) \ No newline at end of file +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 45d0471..62cac69 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -20,7 +20,7 @@ # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c' +SECRET_KEY = "fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c" # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True @@ -31,53 +31,53 @@ # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'channels', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "channels", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -ROOT_URLCONF = 'django_subscriptions.urls' +ROOT_URLCONF = "django_subscriptions.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -WSGI_APPLICATION = 'django_subscriptions.wsgi.application' +WSGI_APPLICATION = "django_subscriptions.wsgi.application" # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), } } @@ -87,26 +87,20 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, + {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, + {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, + {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -118,20 +112,17 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ -STATIC_URL = '/static/' -CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] +STATIC_URL = "/static/" +CHANNELS_WS_PROTOCOLS = [ + "graphql-ws", +] CHANNEL_LAYERS = { "default": { "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": { - "hosts": [("localhost", 6379)], - }, + "CONFIG": {"hosts": [("localhost", 6379)]}, "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -GRAPHENE = { - 'SCHEMA': 'django_subscriptions.schema.schema' -} \ No newline at end of file +GRAPHENE = {"SCHEMA": "django_subscriptions.schema.schema"} diff --git a/examples/django_subscriptions/django_subscriptions/template.py b/examples/django_subscriptions/django_subscriptions/template.py index b067ae5..738d9e7 100644 --- a/examples/django_subscriptions/django_subscriptions/template.py +++ b/examples/django_subscriptions/django_subscriptions/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.11.10', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.11.10", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/urls.py b/examples/django_subscriptions/django_subscriptions/urls.py index 3848d22..caf790d 100644 --- a/examples/django_subscriptions/django_subscriptions/urls.py +++ b/examples/django_subscriptions/django_subscriptions/urls.py @@ -21,20 +21,21 @@ from graphene_django.views import GraphQLView from django.views.decorators.csrf import csrf_exempt +from channels.routing import route_class +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + def graphiql(request): response = HttpResponse(content=render_graphiql()) return response + urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^graphiql/', graphiql), - url(r'^graphql', csrf_exempt(GraphQLView.as_view(graphiql=True))) + url(r"^admin/", admin.site.urls), + url(r"^graphiql/", graphiql), + url(r"^graphql", csrf_exempt(GraphQLView.as_view(graphiql=True))), ] -from channels.routing import route_class -from graphql_ws.django_channels import GraphQLSubscriptionConsumer - channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), -] \ No newline at end of file +] diff --git a/examples/flask_gevent/app.py b/examples/flask_gevent/app.py index dbb0cca..efd145b 100644 --- a/examples/flask_gevent/app.py +++ b/examples/flask_gevent/app.py @@ -1,5 +1,3 @@ -import json - from flask import Flask, make_response from flask_graphql import GraphQLView from flask_sockets import Sockets @@ -14,19 +12,20 @@ sockets = Sockets(app) -@app.route('/graphiql') +@app.route("/graphiql") def graphql_view(): return make_response(render_graphiql()) app.add_url_rule( - '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False)) + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=False) +) subscription_server = GeventSubscriptionServer(schema) -app.app_protocol = lambda environ_path_info: 'graphql-ws' +app.app_protocol = lambda environ_path_info: "graphql-ws" -@sockets.route('/subscriptions') +@sockets.route("/subscriptions") def echo_socket(ws): subscription_server.handle(ws) return [] @@ -35,5 +34,6 @@ def echo_socket(ws): if __name__ == "__main__": from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler - server = pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler) + + server = pywsgi.WSGIServer(("", 5000), app, handler_class=WebSocketHandler) server.serve_forever() diff --git a/examples/flask_gevent/schema.py b/examples/flask_gevent/schema.py index 6e6298c..eb48050 100644 --- a/examples/flask_gevent/schema.py +++ b/examples/flask_gevent/schema.py @@ -19,12 +19,16 @@ class Subscription(graphene.ObjectType): random_int = graphene.Field(RandomType) def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) def resolve_random_int(root, info): - return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) + return Observable.interval(1000).map( + lambda i: RandomType(seconds=i, random_int=random.randint(0, 500)) + ) schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/flask_gevent/template.py b/examples/flask_gevent/template.py index 41f52e1..ea0438c 100644 --- a/examples/flask_gevent/template.py +++ b/examples/flask_gevent/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.12.0', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:5000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.12.0", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:5000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py index 0de6988..7638f3d 100644 --- a/examples/websockets_lib/app.py +++ b/examples/websockets_lib/app.py @@ -8,21 +8,23 @@ app = Sanic(__name__) -@app.listener('before_server_start') +@app.listener("before_server_start") def init_graphql(app, loop): - app.add_route(GraphQLView.as_view(schema=schema, - executor=AsyncioExecutor(loop=loop)), - '/graphql') + app.add_route( + GraphQLView.as_view(schema=schema, executor=AsyncioExecutor(loop=loop)), + "/graphql", + ) -@app.route('/graphiql') +@app.route("/graphiql") async def graphiql_view(request): return response.html(render_graphiql()) + subscription_server = WsLibSubscriptionServer(schema) -@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +@app.websocket("/subscriptions", subprotocols=["graphql-ws"]) async def subscriptions(request, ws): await subscription_server.handle(ws) return ws diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py index 3c23d00..ae107c7 100644 --- a/examples/websockets_lib/schema.py +++ b/examples/websockets_lib/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/websockets_lib/template.py b/examples/websockets_lib/template.py index 03587bb..8f007b9 100644 --- a/examples/websockets_lib/template.py +++ b/examples/websockets_lib/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,9 +116,10 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', - endpointURL='/graphql', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", + endpointURL="/graphql", ) From 9499ae9154f682d055cfebe9c7f7cbc9e4359e3e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 26 Jun 2020 09:59:21 +1200 Subject: [PATCH 10/42] Black the modules in graphql_ws root --- graphql_ws/__init__.py | 4 ++-- graphql_ws/base.py | 4 ++-- graphql_ws/constants.py | 22 +++++++++++----------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 793831a..0ffa258 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -3,5 +3,5 @@ """Top-level package for GraphQL WS.""" __author__ = """Syrus Akbary""" -__email__ = 'me@syrusakbary.com' -__version__ = '0.3.1' +__email__ = "me@syrusakbary.com" +__version__ = "0.3.1" diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 9fb931d..0a2577e 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -99,8 +99,8 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): - context = payload.get('context') or {} - context.setdefault('request_context', connection_context.request_context) + context = payload.get("context") or {} + context.setdefault("request_context", connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 4f9d2f1..8b57a60 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,15 +1,15 @@ -GRAPHQL_WS = 'graphql-ws' +GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS -GQL_CONNECTION_INIT = 'connection_init' # Client -> Server -GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client -GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client +GQL_CONNECTION_INIT = "connection_init" # Client -> Server +GQL_CONNECTION_ACK = "connection_ack" # Server -> Client +GQL_CONNECTION_ERROR = "connection_error" # Server -> Client # NOTE: This one here don't follow the standard due to connection optimization -GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server -GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client -GQL_START = 'start' # Client -> Server -GQL_DATA = 'data' # Server -> Client -GQL_ERROR = 'error' # Server -> Client -GQL_COMPLETE = 'complete' # Server -> Client -GQL_STOP = 'stop' # Client -> Server +GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server +GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client +GQL_START = "start" # Client -> Server +GQL_DATA = "data" # Server -> Client +GQL_ERROR = "error" # Server -> Client +GQL_COMPLETE = "complete" # Server -> Client +GQL_STOP = "stop" # Client -> Server From 738f447cf1d71191c8cd8a6f2865a54dd05cbe72 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 09:34:11 +1200 Subject: [PATCH 11/42] Skip flake8 false positives and remove unneeded import --- tests/test_django_channels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index 137d541..0552c7b 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -5,7 +5,7 @@ import django import mock from channels import Channel -from channels.test import ChannelTestCase, Client +from channels.test import ChannelTestCase from django.conf import settings from django.core.management import call_command @@ -25,8 +25,8 @@ ) django.setup() -from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT -from graphql_ws.django_channels import ( +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT # noqa: E402 +from graphql_ws.django_channels import ( # noqa: E402 DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, GraphQLSubscriptionConsumer, From f641e584f0f32c2733ec986c033634b974c85f78 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 10:03:30 +1200 Subject: [PATCH 12/42] Update contributing doc --- CONTRIBUTING.rst | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 01d606e..a2315ad 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,7 +68,7 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. $ mkvirtualenv graphql_ws $ cd graphql_ws/ - $ python setup.py develop + $ pip install -e .[dev] 4. Create a branch for local development:: @@ -79,11 +79,8 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 graphql_ws tests - $ python setup.py test or py.test $ tox - To get flake8 and tox, just pip install them into your virtualenv. - 6. Commit your changes and push your branch to GitHub:: $ git add . @@ -101,14 +98,6 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 2.6, 2.7, 3.3, 3.4 and 3.5, and for PyPy. Check +3. The pull request should work for Python 2.7, 3.5, 3.6, 3.7 and 3.8. Check https://travis-ci.org/graphql-python/graphql_ws/pull_requests and make sure that the tests pass for all supported Python versions. - -Tips ----- - -To run a subset of tests:: - -$ py.test tests.test_graphql_ws - From c007f58e2e3d5082aedf6a9ec1fac612dc9edf2d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 10:03:42 +1200 Subject: [PATCH 13/42] Correctly test a bad graphql parameter --- tests/test_graphql_ws.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 65cbf91..4a7b845 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -111,7 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} + cc, {"id": "1", "type": None, "payload": {"a": "b"}} ) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") @@ -136,7 +136,7 @@ def test_get_graphql_params(ss, cc): "query": "req", "variables": "vars", "operationName": "query", - "context": "ctx", + "context": {}, } params = ss.get_graphql_params(cc, payload) assert isinstance(params.pop("executor"), SyncExecutor) @@ -144,7 +144,7 @@ def test_get_graphql_params(ss, cc): "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": "ctx", + "context_value": {'request_context': None}, } From bb4f1be10587b08aa85630e72018fd350410e12f Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:24:38 +1200 Subject: [PATCH 14/42] Make removing an operation from context fail silently --- graphql_ws/base.py | 5 ++++- tests/test_base.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 0a2577e..db4f675 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -34,7 +34,10 @@ def get_operation(self, op_id): return self.operations[op_id] def remove_operation(self, op_id): - del self.operations[op_id] + try: + del self.operations[op_id] + except KeyError: + pass def receive(self): raise NotImplementedError("receive method not implemented") diff --git a/tests/test_base.py b/tests/test_base.py index 2e78459..80de021 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -60,3 +60,15 @@ def test_message_invalid(): server.send_error = mock.Mock() server.on_message(connection_context=None, message="'not-json") assert server.send_error.called + + +def test_context_operations(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + assert not context.has_operation(1) + context.register_operation(1, None) + assert context.has_operation(1) + context.remove_operation(1) + assert not context.has_operation(1) + # Removing a non-existant operation fails silently. + context.remove_operation(999) From a4eef790606d7607724ae156fbdd1f7126e17126 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:30:42 +1200 Subject: [PATCH 15/42] Make async methods send an error if an operation raises an exception Also remove iteratable operations from the context when they complete --- graphql_ws/base_async.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 95d2f2b..29dfb08 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -87,16 +87,23 @@ async def on_start(self, connection_context, op_id, params): if hasattr(execution_result, "__aiter__"): iterator = await execution_result.__aiter__() connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break + try: + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + connection_context.remove_operation(op_id) + else: + try: await self.send_execution_result( - connection_context, op_id, single_result + connection_context, op_id, execution_result ) - else: - await self.send_execution_result( - connection_context, op_id, execution_result - ) + except Exception as e: + await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) await self.on_operation_complete(connection_context, op_id) From 3e670e60cd7f0c8e55248dc7cbca09e12ed3be84 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:31:13 +1200 Subject: [PATCH 16/42] Send completion messages when the sync observer completes / errors out. --- graphql_ws/base_sync.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 56b4d42..0f15c01 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -2,7 +2,7 @@ from rx import Observable, Observer from .base import BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR class BaseSyncSubscriptionServer(BaseSubscriptionServer): @@ -49,30 +49,33 @@ def on_start(self, connection_context, op_id, params): op_id, self.send_execution_result, self.send_error, - self.on_close, + self.send_message, ) ) connection_context.register_operation(op_id, disposable) except Exception as e: - self.send_error(connection_context, op_id, str(e)) + self.send_error(connection_context, op_id, e) + self.send_message(connection_context, op_id, GQL_COMPLETE) class SubscriptionObserver(Observer): def __init__( - self, connection_context, op_id, send_execution_result, send_error, on_close + self, connection_context, op_id, send_execution_result, send_error, send_message ): self.connection_context = connection_context self.op_id = op_id self.send_execution_result = send_execution_result self.send_error = send_error - self.on_close = on_close + self.send_message = send_message def on_next(self, value): self.send_execution_result(self.connection_context, self.op_id, value) def on_completed(self): - self.on_close(self.connection_context) + self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) + self.connection_context.remove_operation(self.op_id) def on_error(self, error): self.send_error(self.connection_context, self.op_id, error) + self.on_completed() From 650db340831f35b21d5c4402edf0533391656e01 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 11:46:35 +1200 Subject: [PATCH 17/42] Cody tidy --- graphql_ws/base_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 0f15c01..a6d2efb 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -21,7 +21,7 @@ def on_connect(self, connection_context, payload): pass def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) + remove_operations = list(connection_context.operations) for op_id in remove_operations: self.unsubscribe(connection_context, op_id) From a8c2f33bea6fdc7ca134abfa0a1a34ac2fe94319 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 12:29:33 +1200 Subject: [PATCH 18/42] Abstract ensuring async task is a future --- graphql_ws/aiohttp.py | 6 ++---- graphql_ws/base_async.py | 4 ++-- graphql_ws/websockets_lib.py | 6 ++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 49e0a5e..d2162f2 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,5 +1,5 @@ import json -from asyncio import ensure_future, shield +from asyncio import shield from aiohttp import WSMsgType @@ -45,9 +45,7 @@ async def _handle(self, ws, request_context=None): break connection_context.remember_task( - ensure_future( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message), loop=self.loop ) await self.on_close(connection_context) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 29dfb08..8cdf31d 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -35,8 +35,8 @@ def closed(self): async def close(self, code): ... - def remember_task(self, task): - self.pending_tasks.add(asyncio.ensure_future(task)) + def remember_task(self, task, loop=None): + self.pending_tasks.add(asyncio.ensure_future(task, loop=loop)) # Clear completed tasks self.pending_tasks -= WeakSet( task for task in self.pending_tasks if task.done() diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 93ad76f..4d753a5 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,5 +1,5 @@ import json -from asyncio import ensure_future, shield +from asyncio import shield from websockets import ConnectionClosed @@ -41,9 +41,7 @@ async def _handle(self, ws, request_context): break connection_context.remember_task( - ensure_future( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message), loop=self.loop ) await self.on_close(connection_context) From 8d32f4b67fde158c6f6f2284dbe9854a777b470e Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 14:46:03 +1200 Subject: [PATCH 19/42] Tidy up django_channels (1) backend and example --- .../django_subscriptions/settings.py | 3 +- .../django_subscriptions/requirements.txt | 4 +++ graphql_ws/django_channels.py | 29 ++++++++++--------- 3 files changed, 20 insertions(+), 16 deletions(-) create mode 100644 examples/django_subscriptions/requirements.txt diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 62cac69..7bb3f24 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -118,8 +118,7 @@ ] CHANNEL_LAYERS = { "default": { - "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": {"hosts": [("localhost", 6379)]}, + "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, } diff --git a/examples/django_subscriptions/requirements.txt b/examples/django_subscriptions/requirements.txt new file mode 100644 index 0000000..557e99f --- /dev/null +++ b/examples/django_subscriptions/requirements.txt @@ -0,0 +1,4 @@ +-e ../.. +django<2 +channels<2 +graphene_django<3 \ No newline at end of file diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index fbee47b..ddba58d 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -8,17 +8,18 @@ class DjangoChannelConnectionContext(BaseConnectionContext): - def __init__(self, message, request_context=None): - self.message = message - self.operations = {} - self.request_context = request_context + def __init__(self, message): + super(DjangoChannelConnectionContext, self).__init__( + message.reply_channel, + request_context={"user": message.user, "session": message.http_session}, + ) def send(self, data): - self.message.reply_channel.send({"text": json.dumps(data)}) + self.ws.send({"text": json.dumps(data)}) def close(self, reason): data = {"close": True, "text": reason} - self.message.reply_channel.send(data) + self.ws.send(data) class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): @@ -26,21 +27,21 @@ def handle(self, message, connection_context): self.on_message(connection_context, message) +subscription_server = DjangoChannelSubscriptionServer(graphene_settings.SCHEMA) + + class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True strict_ordering = True - def connect(self, message, **_kwargs): + def connect(self, message, **kwargs): message.reply_channel.send({"accept": True}) - def receive(self, content, **_kwargs): + def receive(self, content, **kwargs): """ Called when a message is received with either text or bytes filled out. """ - self.connection_context = DjangoChannelConnectionContext(self.message) - self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA - ) - self.subscription_server.on_open(self.connection_context) - self.subscription_server.handle(content, self.connection_context) + context = DjangoChannelConnectionContext(self.message) + subscription_server.on_open(context) + subscription_server.handle(content, context) From 7797c29d2ce155988c10a6c61f0b1ce30c866f31 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Jun 2020 15:26:45 +1200 Subject: [PATCH 20/42] Update readme --- README.rst | 199 ++++++++++++++++++++++++++--------------------------- setup.cfg | 2 +- 2 files changed, 100 insertions(+), 101 deletions(-) diff --git a/README.rst b/README.rst index 90ee500..0a871f0 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,23 @@ +========== GraphQL WS ========== -Websocket server for GraphQL subscriptions. +Websocket backend for GraphQL subscriptions. + +Supports the following application servers: + +Python 3 application servers, using asyncio: + + * `aiohttp`_ + * `websockets compatible servers`_ such as Sanic + (via `websockets `__ library) -Currently supports: +Python 2 application servers: + + * `Gevent compatible servers`_ such as Flask + * `Django v1.x`_ + (via `channels v1.x `__) -* `aiohttp `__ -* `Gevent `__ -* Sanic (uses `websockets `__ - library) Installation instructions ========================= @@ -19,21 +28,54 @@ For instaling graphql-ws, just run this command in your shell pip install graphql-ws + Examples --------- +======== + +Python 3 servers +---------------- + +Create a subscribable schema like this: + +.. code:: python + + import asyncio + import graphene + + + class Query(graphene.ObjectType): + hello = graphene.String() + + @static_method + def resolve_hello(obj, info, **kwargs): + return "world" + + + class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + + async def resolve_count_seconds(root, info, up_to): + for i in range(up_to): + yield i + await asyncio.sleep(1.) + yield up_to + + + schema = graphene.Schema(query=Query, subscription=Subscription) aiohttp ~~~~~~~ -For setting up, just plug into your aiohttp server. +Then just plug into your aiohttp server. .. code:: python from graphql_ws.aiohttp import AiohttpSubscriptionServer - + from .schema import schema subscription_server = AiohttpSubscriptionServer(schema) + async def subscriptions(request): ws = web.WebSocketResponse(protocols=('graphql-ws',)) await ws.prepare(request) @@ -47,21 +89,26 @@ For setting up, just plug into your aiohttp server. web.run_app(app, port=8000) -Sanic -~~~~~ +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp + -Works with any framework that uses the websockets library for it’s -websocket implementation. For this example, plug in your Sanic server. +websockets compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Works with any framework that uses the websockets library for its websocket +implementation. For this example, plug in your Sanic server. .. code:: python from graphql_ws.websockets_lib import WsLibSubscriptionServer - + from . import schema app = Sanic(__name__) subscription_server = WsLibSubscriptionServer(schema) + @app.websocket('/subscriptions', subprotocols=['graphql-ws']) async def subscriptions(request, ws): await subscription_server.handle(ws) @@ -70,80 +117,73 @@ websocket implementation. For this example, plug in your Sanic server. app.run(host="0.0.0.0", port=8000) -And then, plug into a subscribable schema: + +Python 2 servers +----------------- + +Create a subscribable schema like this: .. code:: python - import asyncio import graphene + from rx import Observable class Query(graphene.ObjectType): - base = graphene.String() + hello = graphene.String() + + @static_method + def resolve_hello(obj, info, **kwargs): + return "world" class Subscription(graphene.ObjectType): count_seconds = graphene.Float(up_to=graphene.Int()) - async def resolve_count_seconds(root, info, up_to): - for i in range(up_to): - yield i - await asyncio.sleep(1.) - yield up_to + async def resolve_count_seconds(root, info, up_to=5): + return Observable.interval(1000)\ + .map(lambda i: "{0}".format(i))\ + .take_while(lambda i: int(i) <= up_to) schema = graphene.Schema(query=Query, subscription=Subscription) -You can see a full example here: -https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp - -Gevent -~~~~~~ +Gevent compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~ -For setting up, just plug into your Gevent server. +Then just plug into your Gevent server, for example, Flask: .. code:: python + from flask_sockets import Sockets + from graphql_ws.gevent import GeventSubscriptionServer + from schema import schema + subscription_server = GeventSubscriptionServer(schema) app.app_protocol = lambda environ_path_info: 'graphql-ws' + @sockets.route('/subscriptions') def echo_socket(ws): subscription_server.handle(ws) return [] -And then, plug into a subscribable schema: - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - base = graphene.String() - - - class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) - - async def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - schema = graphene.Schema(query=Query, subscription=Subscription) - You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent -Django Channels -~~~~~~~~~~~~~~~ +Django v1.x +~~~~~~~~~~~ -First ``pip install channels`` and it to your django apps +For Django v1.x and Django Channels v1.x, setup your schema in ``settings.py`` -Then add the following to your settings.py +.. code:: python + + GRAPHENE = { + 'SCHEMA': 'yourproject.schema.schema' + } + +Then ``pip install "channels<1"`` and it to your django apps, adding the +following to your ``settings.py`` .. code:: python @@ -153,53 +193,9 @@ Then add the following to your settings.py "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -Setup your graphql schema - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - hello = graphene.String() - - def resolve_hello(self, info, **kwargs): - return 'world' - - class Subscription(graphene.ObjectType): - - count_seconds = graphene.Int(up_to=graphene.Int()) - - - def resolve_count_seconds( - root, - info, - up_to=5 - ): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - - schema = graphene.Schema( - query=Query, - subscription=Subscription - ) - -Setup your schema in settings.py - -.. code:: python - - GRAPHENE = { - 'SCHEMA': 'path.to.schema' - } - -and finally add the channel routes +And finally add the channel routes .. code:: python @@ -209,3 +205,6 @@ and finally add the channel routes channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] + +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/django_subscriptions diff --git a/setup.cfg b/setup.cfg index b921bca..1e7ea2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ [metadata] name = graphql-ws version = 0.3.1 -description = Websocket server for GraphQL subscriptions +description = Websocket backend for GraphQL subscriptions long_description = file: README.rst, CHANGES.rst author = Syrus Akbary author_email = me@syrusakbary.com From 5ed4f1d5f3a947ea5f707a0635677a6334865446 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 1 Jul 2020 09:28:16 +1200 Subject: [PATCH 21/42] Fix a readme typo --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 0a871f0..fb968b6 100644 --- a/README.rst +++ b/README.rst @@ -46,7 +46,7 @@ Create a subscribable schema like this: class Query(graphene.ObjectType): hello = graphene.String() - @static_method + @staticmethod def resolve_hello(obj, info, **kwargs): return "world" @@ -132,7 +132,7 @@ Create a subscribable schema like this: class Query(graphene.ObjectType): hello = graphene.String() - @static_method + @staticmethod def resolve_hello(obj, info, **kwargs): return "world" From a3197d0b2bc0140ebd4c17469aaa730a72e6872b Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 1 Jul 2020 16:34:33 +1200 Subject: [PATCH 22/42] Recursively resolve Promises, fix async tests --- graphql_ws/base_async.py | 55 ++++++++++++++++++++- tests/test_base_async.py | 102 +++++++++++++++++++++++++++------------ 2 files changed, 124 insertions(+), 33 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 8cdf31d..af9e4e4 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -1,9 +1,12 @@ import asyncio +import inspect from abc import ABC, abstractmethod -from inspect import isawaitable +from types import CoroutineType, GeneratorType +from typing import Any, Union, List, Dict from weakref import WeakSet from graphql.execution.executors.asyncio import AsyncioExecutor +from promise import Promise from graphql_ws import base @@ -11,6 +14,49 @@ from .observable_aiter import setup_observable_extension setup_observable_extension() +CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE + + +# Copied from graphql-core v3.1.0 (graphql/pyutils/is_awaitable.py) +def is_awaitable(value: Any) -> bool: + """Return true if object can be passed to an ``await`` expression. + Instead of testing if the object is an instance of abc.Awaitable, it checks + the existence of an `__await__` attribute. This is much faster. + """ + return ( + # check for coroutine objects + isinstance(value, CoroutineType) + # check for old-style generator based coroutine objects + or isinstance(value, GeneratorType) + and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE) + # check for other awaitables (e.g. futures) + or hasattr(value, "__await__") + ) + + +async def resolve( + data: Any, _container: Union[List, Dict] = None, _key: Union[str, int] = None +) -> None: + """ + Recursively wait on any awaitable children of a data element and resolve any + Promises. + """ + if is_awaitable(data): + data = await data + if isinstance(data, Promise): + data = data.value # type: Any + if _container is not None: + _container[_key] = data + if isinstance(data, dict): + items = data.items() + elif isinstance(data, list): + items = enumerate(data) + else: + items = None + if items is not None: + children = [resolve(child, _container=data, _key=key) for key, child in items] + if children: + await asyncio.wait(children) class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): @@ -81,7 +127,7 @@ async def on_connection_init(self, connection_context, op_id, payload): async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) - if isawaitable(execution_result): + if is_awaitable(execution_result): execution_result = await execution_result if hasattr(execution_result, "__aiter__"): @@ -120,3 +166,8 @@ async def on_stop(self, connection_context, op_id): async def on_operation_complete(self, connection_context, op_id): pass + + async def send_execution_result(self, connection_context, op_id, execution_result): + # Resolve any pending promises + await resolve(execution_result.data) + await super().send_execution_result(connection_context, op_id, execution_result) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index 902acc7..d341c18 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -1,59 +1,99 @@ from unittest import mock import json +import promise import pytest -from graphql_ws import base +from graphql_ws import base, base_async +pytestmark = pytest.mark.asyncio -def test_not_implemented(): - server = base.BaseSubscriptionServer(schema=None) - with pytest.raises(NotImplementedError): - server.on_connection_init(connection_context=None, op_id=1, payload={}) - with pytest.raises(NotImplementedError): - server.on_open(connection_context=None) - with pytest.raises(NotImplementedError): - server.on_stop(connection_context=None, op_id=1) +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) -def test_terminate(): - server = base.BaseSubscriptionServer(schema=None) - context = mock.Mock() - server.on_connection_terminate(connection_context=context, op_id=1) +class TestServer(base_async.BaseAsyncSubscriptionServer): + def handle(self, *args, **kwargs): + pass + + +@pytest.fixture +def server(): + return TestServer(schema=None) + + +async def test_terminate(server: TestServer): + context = AsyncMock() + await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) -def test_send_error(): - server = base.BaseSubscriptionServer(schema=None) - context = mock.Mock() - server.send_error(connection_context=context, op_id=1, error="test error") +async def test_send_error(server: TestServer): + context = AsyncMock() + await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} ) -def test_message(): - server = base.BaseSubscriptionServer(schema=None) - server.process_message = mock.Mock() - context = mock.Mock() +async def test_message(server): + server.process_message = AsyncMock() + context = AsyncMock() msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} - server.on_message(context, msg) + await server.on_message(context, msg) server.process_message.assert_called_with(context, msg) -def test_message_str(): - server = base.BaseSubscriptionServer(schema=None) - server.process_message = mock.Mock() - context = mock.Mock() +async def test_message_str(server): + server.process_message = AsyncMock() + context = AsyncMock() msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} - server.on_message(context, json.dumps(msg)) + await server.on_message(context, json.dumps(msg)) server.process_message.assert_called_with(context, msg) -def test_message_invalid(): - server = base.BaseSubscriptionServer(schema=None) - server.send_error = mock.Mock() - server.on_message(connection_context=None, message="'not-json") +async def test_message_invalid(server): + server.send_error = AsyncMock() + await server.on_message(connection_context=None, message="'not-json") assert server.send_error.called + + +async def test_resolver(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, 2]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + + +@pytest.mark.asyncio +async def test_resolver_with_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, 2]} + + +async def test_resolver_with_nested_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + inner = promise.Promise(lambda resolve, reject: resolve(2)) + outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) + result.data = {"test": [1, outer]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, {'in': 2}]} From 84d5d1749ba69b9bc72a0d9100e697a5839e817c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 12:18:22 +1200 Subject: [PATCH 23/42] Ignore cancellederror when closing connections --- graphql_ws/base_async.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index af9e4e4..c4353d7 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -159,7 +159,10 @@ async def on_close(self, connection_context): for op_id in connection_context.operations ) + tuple(task.cancel() for task in connection_context.pending_tasks) if awaitables: - await asyncio.gather(*awaitables, loop=self.loop) + try: + await asyncio.gather(*awaitables, loop=self.loop) + except asyncio.CancelledError: + pass async def on_stop(self, connection_context, op_id): await self.unsubscribe(connection_context, op_id) From 7bfc59094f94dc428752600402ef6d31aaa838bd Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 15:41:46 +1200 Subject: [PATCH 24/42] Fix async processing messages --- graphql_ws/aiohttp.py | 4 +--- graphql_ws/base_async.py | 8 ++++---- graphql_ws/websockets_lib.py | 4 +--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index d2162f2..baf8837 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -44,9 +44,7 @@ async def _handle(self, ws, request_context=None): except ConnectionClosedException: break - connection_context.remember_task( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message) await self.on_close(connection_context) async def handle(self, ws, request_context=None): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index c4353d7..d02cc29 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -81,8 +81,8 @@ def closed(self): async def close(self, code): ... - def remember_task(self, task, loop=None): - self.pending_tasks.add(asyncio.ensure_future(task, loop=loop)) + def remember_task(self, task): + self.pending_tasks.add(task) # Clear completed tasks self.pending_tasks -= WeakSet( task for task in self.pending_tasks if task.done() @@ -102,9 +102,9 @@ async def handle(self, ws, request_context=None): def process_message(self, connection_context, parsed_message): task = asyncio.ensure_future( - super().process_message(connection_context, parsed_message) + super().process_message(connection_context, parsed_message), loop=self.loop ) - connection_context.pending.add(task) + connection_context.remember_task(task) return task async def send_message(self, *args, **kwargs): diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 4d753a5..c0adc67 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -40,9 +40,7 @@ async def _handle(self, ws, request_context): except ConnectionClosedException: break - connection_context.remember_task( - self.on_message(connection_context, message), loop=self.loop - ) + self.on_message(connection_context, message) await self.on_close(connection_context) async def handle(self, ws, request_context=None): From 583f3f0bced9edc3257904fe66f8d2609171fa44 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 29 Jul 2020 16:54:07 +1200 Subject: [PATCH 25/42] Fix async unsubscribe --- graphql_ws/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index db4f675..798d19d 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -178,4 +178,4 @@ def unsubscribe(self, connection_context, op_id): connection_context.get_operation(op_id).dispose() # Close operation connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + return self.on_operation_complete(connection_context, op_id) From de8ced3ab190d89237e08402193ff3b9baee63e2 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 10:30:56 +1200 Subject: [PATCH 26/42] Move unsubscribe logic to the connection context --- graphql_ws/base.py | 32 +++++++++++++++----------------- graphql_ws/base_async.py | 32 ++++++++++++++++++-------------- graphql_ws/base_sync.py | 11 +++-------- tests/test_base.py | 9 +++++++-- tests/test_graphql_ws.py | 7 ++++--- 5 files changed, 47 insertions(+), 44 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 798d19d..35ee2fe 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -35,9 +35,18 @@ def get_operation(self, op_id): def remove_operation(self, op_id): try: - del self.operations[op_id] + return self.operations.pop(op_id) except KeyError: - pass + return + + def unsubscribe(self, op_id): + async_iterator = self.remove_operation(op_id) + if hasattr(async_iterator, 'dispose'): + async_iterator.dispose() + + def unsubscribe_all(self): + for op_id in list(self.operations): + self.unsubscribe(op_id) def receive(self): raise NotImplementedError("receive method not implemented") @@ -76,12 +85,6 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - - # If we already have a subscription with this id, unsubscribe from - # it first - if connection_context.has_operation(op_id): - self.unsubscribe(connection_context, op_id) - params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) @@ -116,7 +119,10 @@ def on_open(self, connection_context): raise NotImplementedError("on_open method not implemented") def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") + return connection_context.unsubscribe(op_id) + + def on_close(self, connection_context): + return connection_context.unsubscribe_all() def send_message(self, connection_context, op_id=None, op_type=None, payload=None): message = self.build_message(op_id, op_type, payload) @@ -171,11 +177,3 @@ def on_message(self, connection_context, message): return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - return self.on_operation_complete(connection_context, op_id) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index d02cc29..6cedc67 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -88,6 +88,20 @@ def remember_task(self, task): task for task in self.pending_tasks if task.done() ) + async def unsubscribe(self, op_id): + super().unsubscribe(op_id) + + async def unsubscribe_all(self): + awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] + for task in self.pending_tasks: + task.cancel() + awaitables.append(task) + if awaitables: + try: + await asyncio.gather(*awaitables) + except asyncio.CancelledError: + pass + class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): graphql_executor = AsyncioExecutor @@ -125,6 +139,10 @@ async def on_connection_init(self, connection_context, op_id, payload): await connection_context.close(1011) async def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + await connection_context.unsubscribe(op_id) + execution_result = self.execute(params) if is_awaitable(execution_result): @@ -153,20 +171,6 @@ async def on_start(self, connection_context, op_id, params): await self.send_message(connection_context, op_id, GQL_COMPLETE) await self.on_operation_complete(connection_context, op_id) - async def on_close(self, connection_context): - awaitables = tuple( - self.unsubscribe(connection_context, op_id) - for op_id in connection_context.operations - ) + tuple(task.cancel() for task in connection_context.pending_tasks) - if awaitables: - try: - await asyncio.gather(*awaitables, loop=self.loop) - except asyncio.CancelledError: - pass - - async def on_stop(self, connection_context, op_id): - await self.unsubscribe(connection_context, op_id) - async def on_operation_complete(self, connection_context, op_id): pass diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index a6d2efb..06db900 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -20,11 +20,6 @@ def on_open(self, connection_context): def on_connect(self, connection_context, payload): pass - def on_close(self, connection_context): - remove_operations = list(connection_context.operations) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - def on_connection_init(self, connection_context, op_id, payload): try: self.on_connect(connection_context, payload) @@ -34,10 +29,10 @@ def on_connection_init(self, connection_context, op_id, payload): self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) connection_context.close(1011) - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + connection_context.unsubscribe(op_id) try: execution_result = self.execute(params) assert isinstance( diff --git a/tests/test_base.py b/tests/test_base.py index 80de021..5b40ac5 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -16,8 +16,13 @@ def test_not_implemented(): server.on_connection_init(connection_context=None, op_id=1, payload={}) with pytest.raises(NotImplementedError): server.on_open(connection_context=None) - with pytest.raises(NotImplementedError): - server.on_stop(connection_context=None, op_id=1) + + +def test_on_stop(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.on_stop(connection_context=context, op_id=1) + context.unsubscribe.assert_called_with(1) def test_terminate(): diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 4a7b845..e29e2a2 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -94,12 +94,12 @@ def test_start_existing_op(self, ss, cc): ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = True - ss.unsubscribe = mock.Mock() + cc.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) - assert ss.unsubscribe.called + assert cc.unsubscribe.called ss.on_start.assert_called_with(cc, "1", {"params": True}) def test_start_bad_graphql_params(self, ss, cc): @@ -162,7 +162,8 @@ def test_build_message_partial(ss): assert ss.build_message(id=None, op_type=None, payload="PAYLOAD") == { "payload": "PAYLOAD" } - assert ss.build_message(id=None, op_type=None, payload=None) == {} + with pytest.raises(AssertionError): + ss.build_message(id=None, op_type=None, payload=None) def test_send_execution_result(ss): From 9bec86e8a6d016d27ade83d7601ae75722005829 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 10:31:51 +1200 Subject: [PATCH 27/42] Remove a redundant async method --- graphql_ws/base_async.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 6cedc67..0d57c42 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -121,9 +121,6 @@ def process_message(self, connection_context, parsed_message): connection_context.remember_task(task) return task - async def send_message(self, *args, **kwargs): - await super().send_message(*args, **kwargs) - async def on_open(self, connection_context): pass From a1d2ebc203f15e98bad234137c324b3b0f5d646c Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 14:21:08 +1200 Subject: [PATCH 28/42] Only send messages for operations that exist --- graphql_ws/base.py | 7 ++++--- graphql_ws/base_async.py | 10 +++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 35ee2fe..1ed2da1 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -125,9 +125,9 @@ def on_close(self, connection_context): return connection_context.unsubscribe_all() def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - return connection_context.send(message) + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return connection_context.send(message) def build_message(self, id, op_type, payload): message = {} @@ -137,6 +137,7 @@ def build_message(self, id, op_type, payload): message["type"] = op_type if payload is not None: message["payload"] = payload + assert message, "You need to send at least one thing" return message def send_execution_result(self, connection_context, op_id, execution_result): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 0d57c42..735818d 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -142,6 +142,7 @@ async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) + connection_context.register_operation(op_id, execution_result) if is_awaitable(execution_result): execution_result = await execution_result @@ -157,7 +158,6 @@ async def on_start(self, connection_context, op_id, params): ) except Exception as e: await self.send_error(connection_context, op_id, e) - connection_context.remove_operation(op_id) else: try: await self.send_execution_result( @@ -166,8 +166,16 @@ async def on_start(self, connection_context, op_id, params): except Exception as e: await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) + connection_context.remove_operation(op_id) await self.on_operation_complete(connection_context, op_id) + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return await connection_context.send(message) + async def on_operation_complete(self, connection_context, op_id): pass From 94d874027edceb8ae56b80907db07d25db705cca Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 16:52:27 +1200 Subject: [PATCH 29/42] Iterators are considered awaitable with the new method, so check only not aiter --- graphql_ws/base_async.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 735818d..7f7e74f 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -143,9 +143,6 @@ async def on_start(self, connection_context, op_id, params): execution_result = self.execute(params) connection_context.register_operation(op_id, execution_result) - if is_awaitable(execution_result): - execution_result = await execution_result - if hasattr(execution_result, "__aiter__"): iterator = await execution_result.__aiter__() connection_context.register_operation(op_id, iterator) @@ -160,6 +157,8 @@ async def on_start(self, connection_context, op_id, params): await self.send_error(connection_context, op_id, e) else: try: + if is_awaitable(execution_result): + execution_result = await execution_result await self.send_execution_result( connection_context, op_id, execution_result ) From 218c7fc5e26ed671f1f56f9aed0548d9026ae637 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 30 Jul 2020 17:35:56 +1200 Subject: [PATCH 30/42] Add request context directly to the payload rather than a request_context key --- graphql_ws/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 1ed2da1..4df2fab 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -105,8 +105,7 @@ def on_connection_terminate(self, connection_context, op_id): return connection_context.close(1011) def get_graphql_params(self, connection_context, payload): - context = payload.get("context") or {} - context.setdefault("request_context", connection_context.request_context) + context = payload.get("context", connection_context.request_context) return { "request_string": payload.get("query"), "variable_values": payload.get("variables"), From ae0b0c7c9124a550c50db23781b0dc590beaec74 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 24 Nov 2020 16:54:44 +1300 Subject: [PATCH 31/42] Correctly unsubscribe after on_start operation is complete --- graphql_ws/base_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 7f7e74f..6954341 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -165,7 +165,7 @@ async def on_start(self, connection_context, op_id, params): except Exception as e: await self.send_error(connection_context, op_id, e) await self.send_message(connection_context, op_id, GQL_COMPLETE) - connection_context.remove_operation(op_id) + await connection_context.unsubscribe(op_id) await self.on_operation_complete(connection_context, op_id) async def send_message( From a964800472035943f1f965b1ca341d1cda147879 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Wed, 25 Nov 2020 00:23:24 +1300 Subject: [PATCH 32/42] Fix tests --- tests/test_graphql_ws.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index e29e2a2..3b85c49 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -1,4 +1,5 @@ from collections import OrderedDict + try: from unittest import mock except ImportError: @@ -95,12 +96,12 @@ def test_start_existing_op(self, ss, cc): cc.has_operation = mock.Mock() cc.has_operation.return_value = True cc.unsubscribe = mock.Mock() - ss.on_start = mock.Mock() + ss.execute = mock.Mock() + ss.send_message = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) assert cc.unsubscribe.called - ss.on_start.assert_called_with(cc, "1", {"params": True}) def test_start_bad_graphql_params(self, ss, cc): ss.get_graphql_params = mock.Mock() @@ -110,9 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.send_error = mock.Mock() ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() - ss.process_message( - cc, {"id": "1", "type": None, "payload": {"a": "b"}} - ) + ss.process_message(cc, {"id": "1", "type": None, "payload": {"a": "b"}}) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") assert isinstance(ss.send_error.call_args[0][2], Exception) @@ -144,7 +143,7 @@ def test_get_graphql_params(ss, cc): "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": {'request_context': None}, + "context_value": {}, } From cdbdda1744f949a656333dcbaa5fb8d9ed3442a6 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 29 Mar 2021 17:25:59 +1300 Subject: [PATCH 33/42] Async unsubscription needs to wait around for the future to cancel --- graphql_ws/base.py | 1 + graphql_ws/base_async.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 4df2fab..31ad657 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -43,6 +43,7 @@ def unsubscribe(self, op_id): async_iterator = self.remove_operation(op_id) if hasattr(async_iterator, 'dispose'): async_iterator.dispose() + return async_iterator def unsubscribe_all(self): for op_id in list(self.operations): diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 6954341..0c62481 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod from types import CoroutineType, GeneratorType -from typing import Any, Union, List, Dict +from typing import Any, Dict, List, Union from weakref import WeakSet from graphql.execution.executors.asyncio import AsyncioExecutor @@ -89,7 +89,12 @@ def remember_task(self, task): ) async def unsubscribe(self, op_id): - super().unsubscribe(op_id) + async_iterator = super().unsubscribe(op_id) + if ( + getattr(async_iterator, "future", None) + and async_iterator.future.cancel() + ): + await async_iterator.future async def unsubscribe_all(self): awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] From 45546366581b6c31b33b917cd5d805a758fe54a4 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:12:12 +1300 Subject: [PATCH 34/42] Allow collection of tests even if aiohttp isn't installed --- tests/test_aiohttp.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 88a48d1..40c43fd 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,15 +1,22 @@ +try: + from aiohttp import WSMsgType + from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer +except ImportError: # pragma: no cover + WSMsgType = None + from unittest import mock import pytest -from aiohttp import WSMsgType -from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer from graphql_ws.base import ConnectionClosedException +if_aiohttp_installed = pytest.mark.skipif( + WSMsgType is None, reason="aiohttp is not installed" +) + class AsyncMock(mock.Mock): def __call__(self, *args, **kwargs): - async def coro(): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -24,6 +31,7 @@ def mock_ws(): return ws +@if_aiohttp_installed @pytest.mark.asyncio class TestConnectionContext: async def test_receive_good_data(self, mock_ws): @@ -69,5 +77,6 @@ async def test_close(self, mock_ws): mock_ws.close.assert_called_with(code=123) +@if_aiohttp_installed def test_subscription_server_smoke(): AiohttpSubscriptionServer(schema=None) From 56f46a1f1c58050af9844c3db865e03c01229745 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:13:39 +1300 Subject: [PATCH 35/42] Make the python 2 async observer send graphql error for exceptions explicitly returned --- graphql_ws/base_sync.py | 6 +++++- tests/test_base.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py index 06db900..f6b6c68 100644 --- a/graphql_ws/base_sync.py +++ b/graphql_ws/base_sync.py @@ -65,7 +65,11 @@ def __init__( self.send_message = send_message def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) + if isinstance(value, Exception): + send_method = self.send_error + else: + send_method = self.send_execution_result + send_method(self.connection_context, self.op_id, value) def on_completed(self): self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) diff --git a/tests/test_base.py b/tests/test_base.py index 5b40ac5..1ce6300 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -8,6 +8,7 @@ import pytest from graphql_ws import base +from graphql_ws.base_sync import SubscriptionObserver def test_not_implemented(): @@ -77,3 +78,35 @@ def test_context_operations(): assert not context.has_operation(1) # Removing a non-existant operation fails silently. context.remove_operation(999) + + +def test_observer_data(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next('data') + assert send_result.called + assert not send_error.called + + +def test_observer_exception(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next(TypeError('some bad message')) + assert send_error.called + assert not send_result.called From 5abd858a813ecb09834b92c22184e3e466aa988d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:15:27 +1300 Subject: [PATCH 36/42] asyncio.wait receiving coroutines is deprecated, create tasks explicitly --- graphql_ws/base_async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index 0c62481..bc98dc5 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -54,7 +54,10 @@ async def resolve( else: items = None if items is not None: - children = [resolve(child, _container=data, _key=key) for key, child in items] + children = [ + asyncio.create_task(resolve(child, _container=data, _key=key)) + for key, child in items + ] if children: await asyncio.wait(children) @@ -90,10 +93,7 @@ def remember_task(self, task): async def unsubscribe(self, op_id): async_iterator = super().unsubscribe(op_id) - if ( - getattr(async_iterator, "future", None) - and async_iterator.future.cancel() - ): + if getattr(async_iterator, "future", None) and async_iterator.future.cancel(): await async_iterator.future async def unsubscribe_all(self): From d2d55a12382e0fc3c938c77fc9fcbd8e72c8ddb7 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:19:52 +1300 Subject: [PATCH 37/42] Tidy up a test warning --- tests/test_base_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d341c18..d1a952b 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -33,6 +33,7 @@ async def test_terminate(server: TestServer): async def test_send_error(server: TestServer): context = AsyncMock() + context.has_operation = mock.Mock() await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} From f7cb773fdb03b47c19172aa2b5f38ac62a4e5b76 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:20:45 +1300 Subject: [PATCH 38/42] Rename TestServer to avoid it being collected by pytest --- tests/test_base_async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d1a952b..d62eda5 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -15,23 +15,23 @@ async def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) -class TestServer(base_async.BaseAsyncSubscriptionServer): +class TstServer(base_async.BaseAsyncSubscriptionServer): def handle(self, *args, **kwargs): - pass + pass # pragma: no cover @pytest.fixture def server(): - return TestServer(schema=None) + return TstServer(schema=None) -async def test_terminate(server: TestServer): +async def test_terminate(server: TstServer): context = AsyncMock() await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) -async def test_send_error(server: TestServer): +async def test_send_error(server: TstServer): context = AsyncMock() context.has_operation = mock.Mock() await server.send_error(connection_context=context, op_id=1, error="test error") From 80890c32124037b533e4a043a899dd60fbae419d Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:21:46 +1300 Subject: [PATCH 39/42] Update test matrix --- setup.cfg | 7 ++----- tests/conftest.py | 2 -- tox.ini | 10 +++++----- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/setup.cfg b/setup.cfg index 1e7ea2a..1e85964 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,15 +15,12 @@ classifiers = License :: OSI Approved :: MIT License Natural Language :: English Programming Language :: Python :: 2 - Programming Language :: Python :: 2.6 Programming Language :: Python :: 2.7 Programming Language :: Python :: 3 - Programming Language :: Python :: 3.3 - Programming Language :: Python :: 3.4 - Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 [options] zip_safe = False @@ -94,4 +91,4 @@ omit = [coverage:report] exclude_lines = pragma: no cover - @abstract \ No newline at end of file + @abstract diff --git a/tests/conftest.py b/tests/conftest.py index fa905b4..595968a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,5 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] - if sys.version_info < (3, 6): - collect_ignore.append("test_gevent.py") else: collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tox.ini b/tox.ini index 42d13b4..62e2f8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,15 @@ [tox] -envlist = +envlist = coverage_setup - py27, py35, py36, py37, py38, flake8 + py27, py36, py37, py38, py39, flake8 coverage_report [travis] python = - 3.8: py38, flake8 + 3.9: py39, flake8 + 3.8: py38 3.7: py37 3.6: py36 - 3.5: py35 2.7: py27 [testenv] @@ -33,4 +33,4 @@ commands = coverage html coverage xml coverage report --include="tests/*" --fail-under=100 -m - coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file + coverage report --omit="tests/*" # --fail-under=90 -m From 9ced6094a3b2af37695711d2f25316b4e1926bb1 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 09:35:30 +1300 Subject: [PATCH 40/42] Update travis python versions --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index a3ef963..5104cdc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,9 +11,9 @@ deploy: install: pip install -U tox-travis language: python python: +- 3.9 - 3.8 - 3.7 - 3.6 -- 3.5 - 2.7 script: tox From 3adfaa9ce13052c4c96ed5d689ec8e82dec76cb1 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 10:56:14 +1300 Subject: [PATCH 41/42] Try using a newer travis dist to fix cryptography building issues --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 5104cdc..67a356c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ deploy: tags: true install: pip install -U tox-travis language: python +dist: focal python: - 3.9 - 3.8 From 703e4074573b2dca068f2eb36a25a06808ec8698 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 30 Mar 2021 11:18:45 +1300 Subject: [PATCH 42/42] Use python 3.6 friendly asyncio method --- graphql_ws/base_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py index bc98dc5..a21ca5e 100644 --- a/graphql_ws/base_async.py +++ b/graphql_ws/base_async.py @@ -55,7 +55,7 @@ async def resolve( items = None if items is not None: children = [ - asyncio.create_task(resolve(child, _container=data, _key=key)) + asyncio.ensure_future(resolve(child, _container=data, _key=key)) for key, child in items ] if children: