diff --git a/.travis.yml b/.travis.yml index a3ef963..67a356c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,10 +10,11 @@ deploy: tags: true install: pip install -U tox-travis language: python +dist: focal python: +- 3.9 - 3.8 - 3.7 - 3.6 -- 3.5 - 2.7 script: tox diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 01d606e..a2315ad 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,7 +68,7 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. $ mkvirtualenv graphql_ws $ cd graphql_ws/ - $ python setup.py develop + $ pip install -e .[dev] 4. Create a branch for local development:: @@ -79,11 +79,8 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 graphql_ws tests - $ python setup.py test or py.test $ tox - To get flake8 and tox, just pip install them into your virtualenv. - 6. Commit your changes and push your branch to GitHub:: $ git add . @@ -101,14 +98,6 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 2.6, 2.7, 3.3, 3.4 and 3.5, and for PyPy. Check +3. The pull request should work for Python 2.7, 3.5, 3.6, 3.7 and 3.8. Check https://travis-ci.org/graphql-python/graphql_ws/pull_requests and make sure that the tests pass for all supported Python versions. - -Tips ----- - -To run a subset of tests:: - -$ py.test tests.test_graphql_ws - diff --git a/README.rst b/README.rst index 90ee500..fb968b6 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,23 @@ +========== GraphQL WS ========== -Websocket server for GraphQL subscriptions. +Websocket backend for GraphQL subscriptions. + +Supports the following application servers: + +Python 3 application servers, using asyncio: + + * `aiohttp`_ + * `websockets compatible servers`_ such as Sanic + (via `websockets `__ library) -Currently supports: +Python 2 application servers: + + * `Gevent compatible servers`_ such as Flask + * `Django v1.x`_ + (via `channels v1.x `__) -* `aiohttp `__ -* `Gevent `__ -* Sanic (uses `websockets `__ - library) Installation instructions ========================= @@ -19,21 +28,54 @@ For instaling graphql-ws, just run this command in your shell pip install graphql-ws + Examples --------- +======== + +Python 3 servers +---------------- + +Create a subscribable schema like this: + +.. code:: python + + import asyncio + import graphene + + + class Query(graphene.ObjectType): + hello = graphene.String() + + @staticmethod + def resolve_hello(obj, info, **kwargs): + return "world" + + + class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + + async def resolve_count_seconds(root, info, up_to): + for i in range(up_to): + yield i + await asyncio.sleep(1.) + yield up_to + + + schema = graphene.Schema(query=Query, subscription=Subscription) aiohttp ~~~~~~~ -For setting up, just plug into your aiohttp server. +Then just plug into your aiohttp server. .. code:: python from graphql_ws.aiohttp import AiohttpSubscriptionServer - + from .schema import schema subscription_server = AiohttpSubscriptionServer(schema) + async def subscriptions(request): ws = web.WebSocketResponse(protocols=('graphql-ws',)) await ws.prepare(request) @@ -47,21 +89,26 @@ For setting up, just plug into your aiohttp server. web.run_app(app, port=8000) -Sanic -~~~~~ +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp + -Works with any framework that uses the websockets library for it’s -websocket implementation. For this example, plug in your Sanic server. +websockets compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Works with any framework that uses the websockets library for its websocket +implementation. For this example, plug in your Sanic server. .. code:: python from graphql_ws.websockets_lib import WsLibSubscriptionServer - + from . import schema app = Sanic(__name__) subscription_server = WsLibSubscriptionServer(schema) + @app.websocket('/subscriptions', subprotocols=['graphql-ws']) async def subscriptions(request, ws): await subscription_server.handle(ws) @@ -70,80 +117,73 @@ websocket implementation. For this example, plug in your Sanic server. app.run(host="0.0.0.0", port=8000) -And then, plug into a subscribable schema: + +Python 2 servers +----------------- + +Create a subscribable schema like this: .. code:: python - import asyncio import graphene + from rx import Observable class Query(graphene.ObjectType): - base = graphene.String() + hello = graphene.String() + + @staticmethod + def resolve_hello(obj, info, **kwargs): + return "world" class Subscription(graphene.ObjectType): count_seconds = graphene.Float(up_to=graphene.Int()) - async def resolve_count_seconds(root, info, up_to): - for i in range(up_to): - yield i - await asyncio.sleep(1.) - yield up_to + async def resolve_count_seconds(root, info, up_to=5): + return Observable.interval(1000)\ + .map(lambda i: "{0}".format(i))\ + .take_while(lambda i: int(i) <= up_to) schema = graphene.Schema(query=Query, subscription=Subscription) -You can see a full example here: -https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp - -Gevent -~~~~~~ +Gevent compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~ -For setting up, just plug into your Gevent server. +Then just plug into your Gevent server, for example, Flask: .. code:: python + from flask_sockets import Sockets + from graphql_ws.gevent import GeventSubscriptionServer + from schema import schema + subscription_server = GeventSubscriptionServer(schema) app.app_protocol = lambda environ_path_info: 'graphql-ws' + @sockets.route('/subscriptions') def echo_socket(ws): subscription_server.handle(ws) return [] -And then, plug into a subscribable schema: - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - base = graphene.String() - - - class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) - - async def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - schema = graphene.Schema(query=Query, subscription=Subscription) - You can see a full example here: https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent -Django Channels -~~~~~~~~~~~~~~~ +Django v1.x +~~~~~~~~~~~ -First ``pip install channels`` and it to your django apps +For Django v1.x and Django Channels v1.x, setup your schema in ``settings.py`` -Then add the following to your settings.py +.. code:: python + + GRAPHENE = { + 'SCHEMA': 'yourproject.schema.schema' + } + +Then ``pip install "channels<1"`` and it to your django apps, adding the +following to your ``settings.py`` .. code:: python @@ -153,53 +193,9 @@ Then add the following to your settings.py "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -Setup your graphql schema - -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - hello = graphene.String() - - def resolve_hello(self, info, **kwargs): - return 'world' - - class Subscription(graphene.ObjectType): - - count_seconds = graphene.Int(up_to=graphene.Int()) - - - def resolve_count_seconds( - root, - info, - up_to=5 - ): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - - - - schema = graphene.Schema( - query=Query, - subscription=Subscription - ) - -Setup your schema in settings.py - -.. code:: python - - GRAPHENE = { - 'SCHEMA': 'path.to.schema' - } - -and finally add the channel routes +And finally add the channel routes .. code:: python @@ -209,3 +205,6 @@ and finally add the channel routes channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] + +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/django_subscriptions diff --git a/examples/aiohttp/app.py b/examples/aiohttp/app.py index 56dcaff..336a0c6 100644 --- a/examples/aiohttp/app.py +++ b/examples/aiohttp/app.py @@ -10,24 +10,25 @@ async def graphql_view(request): payload = await request.json() - response = await schema.execute(payload.get('query', ''), return_promise=True) + response = await schema.execute(payload.get("query", ""), return_promise=True) data = {} if response.errors: - data['errors'] = [format_error(e) for e in response.errors] + data["errors"] = [format_error(e) for e in response.errors] if response.data: - data['data'] = response.data + data["data"] = response.data jsondata = json.dumps(data,) - return web.Response(text=jsondata, headers={'Content-Type': 'application/json'}) + return web.Response(text=jsondata, headers={"Content-Type": "application/json"}) async def graphiql_view(request): - return web.Response(text=render_graphiql(), headers={'Content-Type': 'text/html'}) + return web.Response(text=render_graphiql(), headers={"Content-Type": "text/html"}) + subscription_server = AiohttpSubscriptionServer(schema) async def subscriptions(request): - ws = web.WebSocketResponse(protocols=('graphql-ws',)) + ws = web.WebSocketResponse(protocols=("graphql-ws",)) await ws.prepare(request) await subscription_server.handle(ws) @@ -35,9 +36,9 @@ async def subscriptions(request): app = web.Application() -app.router.add_get('/subscriptions', subscriptions) -app.router.add_get('/graphiql', graphiql_view) -app.router.add_get('/graphql', graphql_view) -app.router.add_post('/graphql', graphql_view) +app.router.add_get("/subscriptions", subscriptions) +app.router.add_get("/graphiql", graphiql_view) +app.router.add_get("/graphql", graphql_view) +app.router.add_post("/graphql", graphql_view) web.run_app(app, port=8000) diff --git a/examples/aiohttp/schema.py b/examples/aiohttp/schema.py index 3c23d00..ae107c7 100644 --- a/examples/aiohttp/schema.py +++ b/examples/aiohttp/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/aiohttp/template.py b/examples/aiohttp/template.py index 0b74e96..709f7cf 100644 --- a/examples/aiohttp/template.py +++ b/examples/aiohttp/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/asgi.py b/examples/django_subscriptions/django_subscriptions/asgi.py index e6edd7d..35b4d4d 100644 --- a/examples/django_subscriptions/django_subscriptions/asgi.py +++ b/examples/django_subscriptions/django_subscriptions/asgi.py @@ -3,4 +3,4 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_subscriptions.settings") -channel_layer = get_channel_layer() \ No newline at end of file +channel_layer = get_channel_layer() diff --git a/examples/django_subscriptions/django_subscriptions/schema.py b/examples/django_subscriptions/django_subscriptions/schema.py index b55d76e..db6893c 100644 --- a/examples/django_subscriptions/django_subscriptions/schema.py +++ b/examples/django_subscriptions/django_subscriptions/schema.py @@ -6,18 +6,19 @@ class Query(graphene.ObjectType): hello = graphene.String() def resolve_hello(self, info, **kwargs): - return 'world' + return "world" + class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) - def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) -schema = graphene.Schema(query=Query, subscription=Subscription) \ No newline at end of file +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 45d0471..7bb3f24 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -20,7 +20,7 @@ # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c' +SECRET_KEY = "fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c" # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True @@ -31,53 +31,53 @@ # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'channels', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "channels", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -ROOT_URLCONF = 'django_subscriptions.urls' +ROOT_URLCONF = "django_subscriptions.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -WSGI_APPLICATION = 'django_subscriptions.wsgi.application' +WSGI_APPLICATION = "django_subscriptions.wsgi.application" # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), } } @@ -87,26 +87,20 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, + {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, + {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, + {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -118,20 +112,16 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ -STATIC_URL = '/static/' -CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] +STATIC_URL = "/static/" +CHANNELS_WS_PROTOCOLS = [ + "graphql-ws", +] CHANNEL_LAYERS = { "default": { - "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": { - "hosts": [("localhost", 6379)], - }, + "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -GRAPHENE = { - 'SCHEMA': 'django_subscriptions.schema.schema' -} \ No newline at end of file +GRAPHENE = {"SCHEMA": "django_subscriptions.schema.schema"} diff --git a/examples/django_subscriptions/django_subscriptions/template.py b/examples/django_subscriptions/django_subscriptions/template.py index b067ae5..738d9e7 100644 --- a/examples/django_subscriptions/django_subscriptions/template.py +++ b/examples/django_subscriptions/django_subscriptions/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.11.10', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.11.10", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/urls.py b/examples/django_subscriptions/django_subscriptions/urls.py index 3848d22..caf790d 100644 --- a/examples/django_subscriptions/django_subscriptions/urls.py +++ b/examples/django_subscriptions/django_subscriptions/urls.py @@ -21,20 +21,21 @@ from graphene_django.views import GraphQLView from django.views.decorators.csrf import csrf_exempt +from channels.routing import route_class +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + def graphiql(request): response = HttpResponse(content=render_graphiql()) return response + urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^graphiql/', graphiql), - url(r'^graphql', csrf_exempt(GraphQLView.as_view(graphiql=True))) + url(r"^admin/", admin.site.urls), + url(r"^graphiql/", graphiql), + url(r"^graphql", csrf_exempt(GraphQLView.as_view(graphiql=True))), ] -from channels.routing import route_class -from graphql_ws.django_channels import GraphQLSubscriptionConsumer - channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), -] \ No newline at end of file +] diff --git a/examples/django_subscriptions/requirements.txt b/examples/django_subscriptions/requirements.txt new file mode 100644 index 0000000..557e99f --- /dev/null +++ b/examples/django_subscriptions/requirements.txt @@ -0,0 +1,4 @@ +-e ../.. +django<2 +channels<2 +graphene_django<3 \ No newline at end of file diff --git a/examples/flask_gevent/app.py b/examples/flask_gevent/app.py index dbb0cca..efd145b 100644 --- a/examples/flask_gevent/app.py +++ b/examples/flask_gevent/app.py @@ -1,5 +1,3 @@ -import json - from flask import Flask, make_response from flask_graphql import GraphQLView from flask_sockets import Sockets @@ -14,19 +12,20 @@ sockets = Sockets(app) -@app.route('/graphiql') +@app.route("/graphiql") def graphql_view(): return make_response(render_graphiql()) app.add_url_rule( - '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False)) + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=False) +) subscription_server = GeventSubscriptionServer(schema) -app.app_protocol = lambda environ_path_info: 'graphql-ws' +app.app_protocol = lambda environ_path_info: "graphql-ws" -@sockets.route('/subscriptions') +@sockets.route("/subscriptions") def echo_socket(ws): subscription_server.handle(ws) return [] @@ -35,5 +34,6 @@ def echo_socket(ws): if __name__ == "__main__": from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler - server = pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler) + + server = pywsgi.WSGIServer(("", 5000), app, handler_class=WebSocketHandler) server.serve_forever() diff --git a/examples/flask_gevent/schema.py b/examples/flask_gevent/schema.py index 6e6298c..eb48050 100644 --- a/examples/flask_gevent/schema.py +++ b/examples/flask_gevent/schema.py @@ -19,12 +19,16 @@ class Subscription(graphene.ObjectType): random_int = graphene.Field(RandomType) def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) def resolve_random_int(root, info): - return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) + return Observable.interval(1000).map( + lambda i: RandomType(seconds=i, random_int=random.randint(0, 500)) + ) schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/flask_gevent/template.py b/examples/flask_gevent/template.py index 41f52e1..ea0438c 100644 --- a/examples/flask_gevent/template.py +++ b/examples/flask_gevent/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.12.0', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:5000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.12.0", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:5000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py index 0de6988..7638f3d 100644 --- a/examples/websockets_lib/app.py +++ b/examples/websockets_lib/app.py @@ -8,21 +8,23 @@ app = Sanic(__name__) -@app.listener('before_server_start') +@app.listener("before_server_start") def init_graphql(app, loop): - app.add_route(GraphQLView.as_view(schema=schema, - executor=AsyncioExecutor(loop=loop)), - '/graphql') + app.add_route( + GraphQLView.as_view(schema=schema, executor=AsyncioExecutor(loop=loop)), + "/graphql", + ) -@app.route('/graphiql') +@app.route("/graphiql") async def graphiql_view(request): return response.html(render_graphiql()) + subscription_server = WsLibSubscriptionServer(schema) -@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +@app.websocket("/subscriptions", subprotocols=["graphql-ws"]) async def subscriptions(request, ws): await subscription_server.handle(ws) return ws diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py index 3c23d00..ae107c7 100644 --- a/examples/websockets_lib/schema.py +++ b/examples/websockets_lib/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/websockets_lib/template.py b/examples/websockets_lib/template.py index 03587bb..8f007b9 100644 --- a/examples/websockets_lib/template.py +++ b/examples/websockets_lib/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,9 +116,10 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', - endpointURL='/graphql', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", + endpointURL="/graphql", ) diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 44c7dc3..0ffa258 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -3,8 +3,5 @@ """Top-level package for GraphQL WS.""" __author__ = """Syrus Akbary""" -__email__ = 'me@syrusakbary.com' -__version__ = '0.3.1' - - -from .base import BaseConnectionContext, BaseSubscriptionServer # noqa: F401 +__email__ = "me@syrusakbary.com" +__version__ = "0.3.1" diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 363ca67..baf8837 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,23 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield +import json +from asyncio import shield from aiohttp import WSMsgType -from graphql.execution.executors.asyncio import AsyncioExecutor -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -from .constants import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_COMPLETE -) -setup_observable_extension() - - -class AiohttpConnectionContext(BaseConnectionContext): +class AiohttpConnectionContext(BaseAsyncConnectionContext): async def receive(self): msg = await self.ws.receive() if msg.type == WSMsgType.TEXT: @@ -32,7 +22,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send_str(data) + await self.ws.send_str(json.dumps(data)) @property def closed(self): @@ -42,21 +32,10 @@ async def close(self, code): await self.ws.close(code=code) -class AiohttpSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(AiohttpSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class AiohttpSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -64,59 +43,9 @@ async def _handle(self, ws, request_context=None): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - - self.on_close(connection_context) - for task in pending: - task.cancel() + self.on_message(connection_context, message) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index f3aa1e7..31ad657 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,16 +1,16 @@ import json from collections import OrderedDict -from graphql import graphql, format_error +from graphql import format_error, graphql from .constants import ( + GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, + GQL_DATA, + GQL_ERROR, GQL_START, GQL_STOP, - GQL_ERROR, - GQL_CONNECTION_ERROR, - GQL_DATA ) @@ -34,7 +34,20 @@ def get_operation(self, op_id): return self.operations[op_id] def remove_operation(self, op_id): - del self.operations[op_id] + try: + return self.operations.pop(op_id) + except KeyError: + return + + def unsubscribe(self, op_id): + async_iterator = self.remove_operation(op_id) + if hasattr(async_iterator, 'dispose'): + async_iterator.dispose() + return async_iterator + + def unsubscribe_all(self): + for op_id in list(self.operations): + self.unsubscribe(op_id) def receive(self): raise NotImplementedError("receive method not implemented") @@ -51,33 +64,19 @@ def close(self, code): class BaseSubscriptionServer(object): + graphql_executor = None def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive - def get_graphql_params(self, connection_context, payload): - return { - 'request_string': payload.get('query'), - 'variable_values': payload.get('variables'), - 'operation_name': payload.get('operationName'), - 'context_value': payload.get('context'), - } - - def build_message(self, id, op_type, payload): - message = {} - if id is not None: - message['id'] = id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - return message + def execute(self, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) def process_message(self, connection_context, parsed_message): - op_id = parsed_message.get('id') - op_type = parsed_message.get('type') - payload = parsed_message.get('payload') + op_id = parsed_message.get("id") + op_type = parsed_message.get("type") + payload = parsed_message.get("payload") if op_type == GQL_CONNECTION_INIT: return self.on_connection_init(connection_context, op_id, payload) @@ -87,27 +86,59 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - params = self.get_graphql_params(connection_context, payload) - if not isinstance(params, dict): - error = Exception( - "Invalid params returned from get_graphql_params!" - " Return values must be a dict.") - return self.send_error(connection_context, op_id, error) - - # If we already have a subscription with this id, unsubscribe from - # it first - if connection_context.has_operation(op_id): - self.unsubscribe(connection_context, op_id) - return self.on_start(connection_context, op_id, params) elif op_type == GQL_STOP: return self.on_stop(connection_context, op_id) else: - return self.send_error(connection_context, op_id, Exception( - "Invalid message type: {}.".format(op_type))) + return self.send_error( + connection_context, + op_id, + Exception("Invalid message type: {}.".format(op_type)), + ) + + def on_connection_init(self, connection_context, op_id, payload): + raise NotImplementedError("on_connection_init method not implemented") + + def on_connection_terminate(self, connection_context, op_id): + return connection_context.close(1011) + + def get_graphql_params(self, connection_context, payload): + context = payload.get("context", connection_context.request_context) + return { + "request_string": payload.get("query"), + "variable_values": payload.get("variables"), + "operation_name": payload.get("operationName"), + "context_value": context, + "executor": self.graphql_executor(), + } + + def on_open(self, connection_context): + raise NotImplementedError("on_open method not implemented") + + def on_stop(self, connection_context, op_id): + return connection_context.unsubscribe(op_id) + + def on_close(self, connection_context): + return connection_context.unsubscribe_all() + + def send_message(self, connection_context, op_id=None, op_type=None, payload=None): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return connection_context.send(message) + + def build_message(self, id, op_type, payload): + message = {} + if id is not None: + message["id"] = id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + assert message, "You need to send at least one thing" + return message def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) @@ -116,86 +147,34 @@ def send_execution_result(self, connection_context, op_id, execution_result): def execution_result_to_dict(self, execution_result): result = OrderedDict() if execution_result.data: - result['data'] = execution_result.data + result["data"] = execution_result.data if execution_result.errors: - result['errors'] = [format_error(error) - for error in execution_result.errors] + result["errors"] = [ + format_error(error) for error in execution_result.errors + ] return result - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - json_message = json.dumps(message) - return connection_context.send(json_message) - def send_error(self, connection_context, op_id, error, error_type=None): if error_type is None: error_type = GQL_ERROR assert error_type in [GQL_CONNECTION_ERROR, GQL_ERROR], ( - 'error_type should be one of the allowed error messages' - ' GQL_CONNECTION_ERROR or GQL_ERROR' - ) - - error_payload = { - 'message': str(error) - } - - return self.send_message( - connection_context, - op_id, - error_type, - error_payload + "error_type should be one of the allowed error messages" + " GQL_CONNECTION_ERROR or GQL_ERROR" ) - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + error_payload = {"message": str(error)} - def on_operation_complete(self, connection_context, op_id): - pass - - def on_connection_terminate(self, connection_context, op_id): - return connection_context.close(1011) - - def execute(self, request_context, params): - return graphql( - self.schema, **dict(params, allow_subscriptions=True)) - - def handle(self, ws, request_context=None): - raise NotImplementedError("handle method not implemented") + return self.send_message(connection_context, op_id, error_type, error_payload) def on_message(self, connection_context, message): try: if not isinstance(message, dict): parsed_message = json.loads(message) - assert isinstance( - parsed_message, dict), "Payload must be an object." + assert isinstance(parsed_message, dict), "Payload must be an object." else: parsed_message = message except Exception as e: return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def on_open(self, connection_context): - raise NotImplementedError("on_open method not implemented") - - def on_connect(self, connection_context, payload): - raise NotImplementedError("on_connect method not implemented") - - def on_close(self, connection_context): - raise NotImplementedError("on_close method not implemented") - - def on_connection_init(self, connection_context, op_id, payload): - raise NotImplementedError("on_connection_init method not implemented") - - def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") - - def on_start(self, connection_context, op_id, params): - raise NotImplementedError("on_start method not implemented") diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py new file mode 100644 index 0000000..a21ca5e --- /dev/null +++ b/graphql_ws/base_async.py @@ -0,0 +1,189 @@ +import asyncio +import inspect +from abc import ABC, abstractmethod +from types import CoroutineType, GeneratorType +from typing import Any, Dict, List, Union +from weakref import WeakSet + +from graphql.execution.executors.asyncio import AsyncioExecutor +from promise import Promise + +from graphql_ws import base + +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .observable_aiter import setup_observable_extension + +setup_observable_extension() +CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE + + +# Copied from graphql-core v3.1.0 (graphql/pyutils/is_awaitable.py) +def is_awaitable(value: Any) -> bool: + """Return true if object can be passed to an ``await`` expression. + Instead of testing if the object is an instance of abc.Awaitable, it checks + the existence of an `__await__` attribute. This is much faster. + """ + return ( + # check for coroutine objects + isinstance(value, CoroutineType) + # check for old-style generator based coroutine objects + or isinstance(value, GeneratorType) + and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE) + # check for other awaitables (e.g. futures) + or hasattr(value, "__await__") + ) + + +async def resolve( + data: Any, _container: Union[List, Dict] = None, _key: Union[str, int] = None +) -> None: + """ + Recursively wait on any awaitable children of a data element and resolve any + Promises. + """ + if is_awaitable(data): + data = await data + if isinstance(data, Promise): + data = data.value # type: Any + if _container is not None: + _container[_key] = data + if isinstance(data, dict): + items = data.items() + elif isinstance(data, list): + items = enumerate(data) + else: + items = None + if items is not None: + children = [ + asyncio.ensure_future(resolve(child, _container=data, _key=key)) + for key, child in items + ] + if children: + await asyncio.wait(children) + + +class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): + def __init__(self, ws, request_context=None): + super().__init__(ws, request_context=request_context) + self.pending_tasks = WeakSet() + + @abstractmethod + async def receive(self): + raise NotImplementedError("receive method not implemented") + + @abstractmethod + async def send(self, data): + ... + + @property + @abstractmethod + def closed(self): + ... + + @abstractmethod + async def close(self, code): + ... + + def remember_task(self, task): + self.pending_tasks.add(task) + # Clear completed tasks + self.pending_tasks -= WeakSet( + task for task in self.pending_tasks if task.done() + ) + + async def unsubscribe(self, op_id): + async_iterator = super().unsubscribe(op_id) + if getattr(async_iterator, "future", None) and async_iterator.future.cancel(): + await async_iterator.future + + async def unsubscribe_all(self): + awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] + for task in self.pending_tasks: + task.cancel() + awaitables.append(task) + if awaitables: + try: + await asyncio.gather(*awaitables) + except asyncio.CancelledError: + pass + + +class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): + graphql_executor = AsyncioExecutor + + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + @abstractmethod + async def handle(self, ws, request_context=None): + ... + + def process_message(self, connection_context, parsed_message): + task = asyncio.ensure_future( + super().process_message(connection_context, parsed_message), loop=self.loop + ) + connection_context.remember_task(task) + return task + + async def on_open(self, connection_context): + pass + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + await connection_context.unsubscribe(op_id) + + execution_result = self.execute(params) + + connection_context.register_operation(op_id, execution_result) + if hasattr(execution_result, "__aiter__"): + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + try: + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + else: + try: + if is_awaitable(execution_result): + execution_result = await execution_result + await self.send_execution_result( + connection_context, op_id, execution_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + await connection_context.unsubscribe(op_id) + await self.on_operation_complete(connection_context, op_id) + + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return await connection_context.send(message) + + async def on_operation_complete(self, connection_context, op_id): + pass + + async def send_execution_result(self, connection_context, op_id, execution_result): + # Resolve any pending promises + await resolve(execution_result.data) + await super().send_execution_result(connection_context, op_id, execution_result) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py new file mode 100644 index 0000000..f6b6c68 --- /dev/null +++ b/graphql_ws/base_sync.py @@ -0,0 +1,80 @@ +from graphql.execution.executors.sync import SyncExecutor +from rx import Observable, Observer + +from .base import BaseSubscriptionServer +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + + +class BaseSyncSubscriptionServer(BaseSubscriptionServer): + graphql_executor = SyncExecutor + + def on_operation_complete(self, connection_context, op_id): + pass + + def handle(self, ws, request_context=None): + raise NotImplementedError("handle method not implemented") + + def on_open(self, connection_context): + pass + + def on_connect(self, connection_context, payload): + pass + + def on_connection_init(self, connection_context, op_id, payload): + try: + self.on_connect(connection_context, payload) + self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + + except Exception as e: + self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + connection_context.close(1011) + + def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + connection_context.unsubscribe(op_id) + try: + execution_result = self.execute(params) + assert isinstance( + execution_result, Observable + ), "A subscription must return an observable" + disposable = execution_result.subscribe( + SubscriptionObserver( + connection_context, + op_id, + self.send_execution_result, + self.send_error, + self.send_message, + ) + ) + connection_context.register_operation(op_id, disposable) + + except Exception as e: + self.send_error(connection_context, op_id, e) + self.send_message(connection_context, op_id, GQL_COMPLETE) + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, send_message + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.send_message = send_message + + def on_next(self, value): + if isinstance(value, Exception): + send_method = self.send_error + else: + send_method = self.send_execution_result + send_method(self.connection_context, self.op_id, value) + + def on_completed(self): + self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) + self.connection_context.remove_operation(self.op_id) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) + self.on_completed() diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 4f9d2f1..8b57a60 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,15 +1,15 @@ -GRAPHQL_WS = 'graphql-ws' +GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS -GQL_CONNECTION_INIT = 'connection_init' # Client -> Server -GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client -GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client +GQL_CONNECTION_INIT = "connection_init" # Client -> Server +GQL_CONNECTION_ACK = "connection_ack" # Server -> Client +GQL_CONNECTION_ERROR = "connection_error" # Server -> Client # NOTE: This one here don't follow the standard due to connection optimization -GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server -GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client -GQL_START = 'start' # Client -> Server -GQL_DATA = 'data' # Server -> Client -GQL_ERROR = 'error' # Server -> Client -GQL_COMPLETE = 'complete' # Server -> Client -GQL_STOP = 'stop' # Client -> Server +GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server +GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client +GQL_START = "start" # Client -> Server +GQL_DATA = "data" # Server -> Client +GQL_ERROR = "error" # Server -> Client +GQL_COMPLETE = "complete" # Server -> Client +GQL_STOP = "stop" # Client -> Server diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index 61a7247..ddba58d 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -1,129 +1,47 @@ import json -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor - from channels.generic.websockets import JsonWebsocketConsumer from graphene_django.settings import graphene_settings -from .base import BaseConnectionContext, BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .base import BaseConnectionContext +from .base_sync import BaseSyncSubscriptionServer class DjangoChannelConnectionContext(BaseConnectionContext): - - def __init__(self, message, request_context=None): - self.message = message - self.operations = {} - self.request_context = request_context + def __init__(self, message): + super(DjangoChannelConnectionContext, self).__init__( + message.reply_channel, + request_context={"user": message.user, "session": message.http_session}, + ) def send(self, data): - self.message.reply_channel.send(data) + self.ws.send({"text": json.dumps(data)}) def close(self, reason): - data = { - 'close': True, - 'text': reason - } - self.message.reply_channel.send(data) - + data = {"close": True, "text": reason} + self.ws.send(data) -class DjangoChannelSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(DjangoChannelSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) +class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, message, connection_context): self.on_message(connection_context, message) - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = {} - if op_id is not None: - message['id'] = op_id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - - assert message, "You need to send at least one thing" - return connection_context.send({'text': json.dumps(message)}) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) +subscription_server = DjangoChannelSubscriptionServer(graphene_settings.SCHEMA) class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True strict_ordering = True - def connect(self, message, **_kwargs): + def connect(self, message, **kwargs): message.reply_channel.send({"accept": True}) - def receive(self, content, **_kwargs): + def receive(self, content, **kwargs): """ Called when a message is received with either text or bytes filled out. """ - self.connection_context = DjangoChannelConnectionContext(self.message) - self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA) - self.subscription_server.on_open(self.connection_context) - self.subscription_server.handle(content, self.connection_context) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) + context = DjangoChannelConnectionContext(self.message) + subscription_server.on_open(context) + subscription_server.handle(content, context) diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index aadbe64..b7d6849 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -1,15 +1,15 @@ from __future__ import absolute_import -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor +import json from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + BaseConnectionContext, + ConnectionClosedException, +) +from .base_sync import BaseSyncSubscriptionServer class GeventConnectionContext(BaseConnectionContext): - def receive(self): msg = self.ws.receive() return msg @@ -17,7 +17,7 @@ def receive(self): def send(self, data): if self.closed: return - self.ws.send(data) + self.ws.send(json.dumps(data)) @property def closed(self): @@ -27,13 +27,7 @@ def close(self, code): self.ws.close(code) -class GeventSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(GeventSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class GeventSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, ws, request_context=None): connection_context = GeventConnectionContext(ws, request_context) self.on_open(connection_context) @@ -46,62 +40,3 @@ def handle(self, ws, request_context=None): self.on_close(connection_context) return self.on_message(connection_context, message) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/observable_aiter.py b/graphql_ws/observable_aiter.py index 0bd1a59..424d95f 100644 --- a/graphql_ws/observable_aiter.py +++ b/graphql_ws/observable_aiter.py @@ -1,7 +1,7 @@ from asyncio import Future -from rx.internal import extensionmethod from rx.core import Observable +from rx.internal import extensionmethod async def __aiter__(self): @@ -13,15 +13,11 @@ def __init__(self): self.future = Future() self.disposable = source.materialize().subscribe(self.on_next) - # self.disposed = False def __aiter__(self): return self def dispose(self): - # self.future.cancel() - # self.disposed = True - # self.future.set_exception(StopAsyncIteration) self.disposable.dispose() def feeder(self): @@ -30,11 +26,11 @@ def feeder(self): notification = self.notifications.pop(0) kind = notification.kind - if kind == 'N': + if kind == "N": self.future.set_result(notification.value) - if kind == 'E': + if kind == "E": self.future.set_exception(notification.exception) - if kind == 'C': + if kind == "C": self.future.set_exception(StopAsyncIteration) def on_next(self, notification): @@ -42,8 +38,6 @@ def on_next(self, notification): self.feeder() async def __anext__(self): - # if self.disposed: - # raise StopAsyncIteration self.feeder() value = await self.future @@ -53,38 +47,5 @@ async def __anext__(self): return AIterator() -# def __aiter__(self, sentinel=None): -# loop = get_event_loop() -# future = [Future()] -# notifications = [] - -# def feeder(): -# if not len(notifications) or future[0].done(): -# return -# notification = notifications.pop(0) -# if notification.kind == "E": -# future[0].set_exception(notification.exception) -# elif notification.kind == "C": -# future[0].set_exception(StopIteration(sentinel)) -# else: -# future[0].set_result(notification.value) - -# def on_next(value): -# """Takes on_next values and appends them to the notification queue""" -# notifications.append(value) -# loop.call_soon(feeder) - -# self.materialize().subscribe(on_next) - -# @asyncio.coroutine -# def gen(): -# """Generator producing futures""" -# loop.call_soon(feeder) -# future[0] = Future() -# return future[0] - -# return gen - - def setup_observable_extension(): extensionmethod(Observable)(__aiter__) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 7e78d5d..c0adc67 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,19 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield -from websockets import ConnectionClosed -from graphql.execution.executors.asyncio import AsyncioExecutor - -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +import json +from asyncio import shield -from .constants import ( - GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE) +from websockets import ConnectionClosed -setup_observable_extension() +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -class WsLibConnectionContext(BaseConnectionContext): +class WsLibConnectionContext(BaseAsyncConnectionContext): async def receive(self): try: msg = await self.ws.recv() @@ -24,7 +18,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send(data) + await self.ws.send(json.dumps(data)) @property def closed(self): @@ -34,21 +28,10 @@ async def close(self, code): await self.ws.close(code) -class WsLibSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(WsLibSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class WsLibSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context): connection_context = WsLibConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -56,61 +39,9 @@ async def _handle(self, ws, request_context): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - self.on_close(connection_context) - for task in pending: - task.cancel() + self.on_message(connection_context, message) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error( - connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/setup.cfg b/setup.cfg index df50b23..1e85964 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ [metadata] name = graphql-ws version = 0.3.1 -description = Websocket server for GraphQL subscriptions +description = Websocket backend for GraphQL subscriptions long_description = file: README.rst, CHANGES.rst author = Syrus Akbary author_email = me@syrusakbary.com @@ -15,15 +15,12 @@ classifiers = License :: OSI Approved :: MIT License Natural Language :: English Programming Language :: Python :: 2 - Programming Language :: Python :: 2.6 Programming Language :: Python :: 2.7 Programming Language :: Python :: 3 - Programming Language :: Python :: 3.3 - Programming Language :: Python :: 3.4 - Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 [options] zip_safe = False @@ -90,3 +87,8 @@ ignore = W503 [coverage:run] omit = .tox/* + +[coverage:report] +exclude_lines = + pragma: no cover + @abstract diff --git a/tests/conftest.py b/tests/conftest.py index e551557..595968a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,5 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] - if sys.version_info < (3, 6): - collect_ignore.append('test_gevent.py') else: - collect_ignore = ["test_aiohttp.py"] + collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tests/django_routing.py b/tests/django_routing.py new file mode 100644 index 0000000..9d01766 --- /dev/null +++ b/tests/django_routing.py @@ -0,0 +1,6 @@ +from channels.routing import route +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + +channel_routing = [ + route("websocket.receive", GraphQLSubscriptionConsumer), +] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f20ca15..40c43fd 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,15 +1,22 @@ +try: + from aiohttp import WSMsgType + from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer +except ImportError: # pragma: no cover + WSMsgType = None + from unittest import mock import pytest -from aiohttp import WSMsgType -from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer from graphql_ws.base import ConnectionClosedException +if_aiohttp_installed = pytest.mark.skipif( + WSMsgType is None, reason="aiohttp is not installed" +) + class AsyncMock(mock.Mock): def __call__(self, *args, **kwargs): - async def coro(): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -24,6 +31,7 @@ def mock_ws(): return ws +@if_aiohttp_installed @pytest.mark.asyncio class TestConnectionContext: async def test_receive_good_data(self, mock_ws): @@ -55,7 +63,7 @@ async def test_receive_closed(self, mock_ws): async def test_send(self, mock_ws): connection_context = AiohttpConnectionContext(ws=mock_ws) await connection_context.send("test") - mock_ws.send_str.assert_called_with("test") + mock_ws.send_str.assert_called_with('"test"') async def test_send_closed(self, mock_ws): mock_ws.closed = True @@ -69,5 +77,6 @@ async def test_close(self, mock_ws): mock_ws.close.assert_called_with(code=123) +@if_aiohttp_installed def test_subscription_server_smoke(): AiohttpSubscriptionServer(schema=None) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..1ce6300 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,112 @@ +try: + from unittest import mock +except ImportError: + import mock + +import json + +import pytest + +from graphql_ws import base +from graphql_ws.base_sync import SubscriptionObserver + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + + +def test_on_stop(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.on_stop(connection_context=context, op_id=1) + context.unsubscribe.assert_called_with(1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called + + +def test_context_operations(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + assert not context.has_operation(1) + context.register_operation(1, None) + assert context.has_operation(1) + context.remove_operation(1) + assert not context.has_operation(1) + # Removing a non-existant operation fails silently. + context.remove_operation(999) + + +def test_observer_data(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next('data') + assert send_result.called + assert not send_error.called + + +def test_observer_exception(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next(TypeError('some bad message')) + assert send_error.called + assert not send_result.called diff --git a/tests/test_base_async.py b/tests/test_base_async.py new file mode 100644 index 0000000..d62eda5 --- /dev/null +++ b/tests/test_base_async.py @@ -0,0 +1,100 @@ +from unittest import mock + +import json +import promise + +import pytest + +from graphql_ws import base, base_async + +pytestmark = pytest.mark.asyncio + + +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +class TstServer(base_async.BaseAsyncSubscriptionServer): + def handle(self, *args, **kwargs): + pass # pragma: no cover + + +@pytest.fixture +def server(): + return TstServer(schema=None) + + +async def test_terminate(server: TstServer): + context = AsyncMock() + await server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +async def test_send_error(server: TstServer): + context = AsyncMock() + context.has_operation = mock.Mock() + await server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +async def test_message(server): + server.process_message = AsyncMock() + context = AsyncMock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + await server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +async def test_message_str(server): + server.process_message = AsyncMock() + context = AsyncMock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + await server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +async def test_message_invalid(server): + server.send_error = AsyncMock() + await server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called + + +async def test_resolver(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, 2]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + + +@pytest.mark.asyncio +async def test_resolver_with_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, 2]} + + +async def test_resolver_with_nested_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + inner = promise.Promise(lambda resolve, reject: resolve(2)) + outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) + result.data = {"test": [1, outer]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, {'in': 2}]} diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index e7b054c..0552c7b 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -1,11 +1,35 @@ +from __future__ import unicode_literals + +import json + +import django import mock +from channels import Channel +from channels.test import ChannelTestCase from django.conf import settings +from django.core.management import call_command -settings.configure() # noqa +settings.configure( + CHANNEL_LAYERS={ + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "tests.django_routing.channel_routing", + }, + }, + INSTALLED_APPS=[ + "django.contrib.sessions", + "django.contrib.contenttypes", + "django.contrib.auth", + ], + DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}}, +) +django.setup() -from graphql_ws.django_channels import ( +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT # noqa: E402 +from graphql_ws.django_channels import ( # noqa: E402 DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, + GraphQLSubscriptionConsumer, ) @@ -14,7 +38,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with("test") + msg.reply_channel.send.assert_called_with({"text": '"test"'}) def test_close(self): msg = mock.Mock() @@ -25,3 +49,21 @@ def test_close(self): def test_subscription_server_smoke(): DjangoChannelSubscriptionServer(schema=None) + + +class TestConsumer(ChannelTestCase): + def test_connect(self): + call_command("migrate") + Channel("websocket.receive").send( + { + "path": "/graphql", + "order": 0, + "reply_channel": "websocket.receive", + "text": json.dumps({"type": GQL_CONNECTION_INIT, "id": 1}), + } + ) + message = self.get_next_message("websocket.receive", require=True) + GraphQLSubscriptionConsumer(message) + result = self.get_next_message("websocket.receive", require=True) + result_content = json.loads(result.content["text"]) + assert result_content == {"type": GQL_CONNECTION_ACK} diff --git a/tests/test_gevent.py b/tests/test_gevent.py index f766c5a..a734970 100644 --- a/tests/test_gevent.py +++ b/tests/test_gevent.py @@ -17,8 +17,8 @@ def test_send(self): ws = mock.Mock() ws.closed = False connection_context = GeventConnectionContext(ws=ws) - connection_context.send("test") - ws.send.assert_called_with("test") + connection_context.send({"text": "test"}) + ws.send.assert_called_with('{"text": "test"}') def test_send_closed(self): ws = mock.Mock() diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3ba1120..3b85c49 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -1,12 +1,14 @@ from collections import OrderedDict + try: from unittest import mock except ImportError: import mock import pytest +from graphql.execution.executors.sync import SyncExecutor -from graphql_ws import base, constants +from graphql_ws import base, base_sync, constants @pytest.fixture @@ -18,7 +20,7 @@ def cc(): @pytest.fixture def ss(): - return base.BaseSubscriptionServer(schema=None) + return base_sync.BaseSyncSubscriptionServer(schema=None) class TestConnectionContextOperation: @@ -93,13 +95,13 @@ def test_start_existing_op(self, ss, cc): ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = True - ss.unsubscribe = mock.Mock() - ss.on_start = mock.Mock() + cc.unsubscribe = mock.Mock() + ss.execute = mock.Mock() + ss.send_message = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) - assert ss.unsubscribe.called - ss.on_start.assert_called_with(cc, "1", {"params": True}) + assert cc.unsubscribe.called def test_start_bad_graphql_params(self, ss, cc): ss.get_graphql_params = mock.Mock() @@ -109,9 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.send_error = mock.Mock() ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() - ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} - ) + ss.process_message(cc, {"id": "1", "type": None, "payload": {"a": "b"}}) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") assert isinstance(ss.send_error.call_args[0][2], Exception) @@ -135,13 +135,15 @@ def test_get_graphql_params(ss, cc): "query": "req", "variables": "vars", "operationName": "query", - "context": "ctx", + "context": {}, } - assert ss.get_graphql_params(cc, payload) == { + params = ss.get_graphql_params(cc, payload) + assert isinstance(params.pop("executor"), SyncExecutor) + assert params == { "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": "ctx", + "context_value": {}, } @@ -159,7 +161,8 @@ def test_build_message_partial(ss): assert ss.build_message(id=None, op_type=None, payload="PAYLOAD") == { "payload": "PAYLOAD" } - assert ss.build_message(id=None, op_type=None, payload=None) == {} + with pytest.raises(AssertionError): + ss.build_message(id=None, op_type=None, payload=None) def test_send_execution_result(ss): @@ -189,34 +192,10 @@ def test_send_message(ss, cc): cc.send = mock.Mock() cc.send.return_value = "returned" assert "returned" == ss.send_message(cc) - cc.send.assert_called_with('{"mess": "age"}') + cc.send.assert_called_with({"mess": "age"}) class TestSSNotImplemented: def test_handle(self, ss): with pytest.raises(NotImplementedError): ss.handle(ws=None, request_context=None) - - def test_on_open(self, ss): - with pytest.raises(NotImplementedError): - ss.on_open(connection_context=None) - - def test_on_connect(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connect(connection_context=None, payload=None) - - def test_on_close(self, ss): - with pytest.raises(NotImplementedError): - ss.on_close(connection_context=None) - - def test_on_connection_init(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connection_init(connection_context=None, op_id=None, payload=None) - - def test_on_stop(self, ss): - with pytest.raises(NotImplementedError): - ss.on_stop(connection_context=None, op_id=None) - - def test_on_start(self, ss): - with pytest.raises(NotImplementedError): - ss.on_start(connection_context=None, op_id=None, params=None) diff --git a/tox.ini b/tox.ini index 6de6deb..62e2f8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,15 @@ [tox] -envlist = +envlist = coverage_setup - py27, py35, py36, py37, py38, flake8 + py27, py36, py37, py38, py39, flake8 coverage_report [travis] python = - 3.8: py38, flake8 + 3.9: py39, flake8 + 3.8: py38 3.7: py37 3.6: py36 - 3.5: py35 2.7: py27 [testenv] @@ -31,5 +31,6 @@ skip_install = true deps = coverage commands = coverage html + coverage xml coverage report --include="tests/*" --fail-under=100 -m - coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file + coverage report --omit="tests/*" # --fail-under=90 -m