Skip to content

Commit

Permalink
Migrate graphql-transport-ws types to TypedDicts (#3701)
Browse files Browse the repository at this point in the history
* Migrate graphql-transport-ws types to TypedDicts

* Add release file

* Fix quotes in f-strings

* Make more use of type inference

* Make more use of type inference in tests too
  • Loading branch information
DoctorJohn authored Nov 18, 2024
1 parent 275ebc1 commit efc736f
Show file tree
Hide file tree
Showing 9 changed files with 979 additions and 880 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

In this release, we migrated the `graphql-transport-ws` types from data classes to typed dicts.
Using typed dicts enabled us to precisely model `null` versus `undefined` values, which are common in that protocol.
As a result, we could remove custom conversion methods handling these cases and simplify the codebase.
59 changes: 26 additions & 33 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,10 @@

from channels.testing.websocket import WebsocketCommunicator
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
ConnectionAckMessage,
ConnectionInitMessage,
ErrorMessage,
NextMessage,
SubscribeMessage,
SubscribeMessagePayload,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionAckMessage as GraphQLWSConnectionAckMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage as GraphQLWSConnectionInitMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
StartMessage as GraphQLWSStartMessage,
from strawberry.subscriptions.protocols.graphql_transport_ws import (
types as transport_ws_types,
)
from strawberry.subscriptions.protocols.graphql_ws import types as ws_types
from strawberry.types import ExecutionResult

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,19 +96,21 @@ async def gql_init(self) -> None:
if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
assert res == (True, GRAPHQL_TRANSPORT_WS_PROTOCOL)
await self.send_json_to(
ConnectionInitMessage(payload=self.connection_params).as_dict()
transport_ws_types.ConnectionInitMessage(
{"type": "connection_init", "payload": self.connection_params}
)
)
graphql_transport_ws_response = await self.receive_json_from()
assert graphql_transport_ws_response == ConnectionAckMessage().as_dict()
transport_ws_connection_ack_message: transport_ws_types.ConnectionAckMessage = await self.receive_json_from()
assert transport_ws_connection_ack_message == {"type": "connection_ack"}
else:
assert res == (True, GRAPHQL_WS_PROTOCOL)
await self.send_json_to(
GraphQLWSConnectionInitMessage({"type": "connection_init"})
ws_types.ConnectionInitMessage({"type": "connection_init"})
)
graphql_ws_response: GraphQLWSConnectionAckMessage = (
ws_connection_ack_message: ws_types.ConnectionAckMessage = (
await self.receive_json_from()
)
assert graphql_ws_response["type"] == "connection_ack"
assert ws_connection_ack_message["type"] == "connection_ack"

# Actual `ExecutionResult`` objects are not available client-side, since they
# get transformed into `FormattedExecutionResult` on the wire, but we attempt
Expand All @@ -133,13 +122,16 @@ async def subscribe(

if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
await self.send_json_to(
SubscribeMessage(
id=id_,
payload=SubscribeMessagePayload(query=query, variables=variables),
).as_dict()
transport_ws_types.SubscribeMessage(
{
"id": id_,
"type": "subscribe",
"payload": {"query": query, "variables": variables},
}
)
)
else:
start_message: GraphQLWSStartMessage = {
start_message: ws_types.StartMessage = {
"type": "start",
"id": id_,
"payload": {
Expand All @@ -153,17 +145,18 @@ async def subscribe(
await self.send_json_to(start_message)

while True:
response = await self.receive_json_from(timeout=5)
message_type = response["type"]
if message_type == NextMessage.type:
payload = NextMessage(**response).payload
message: transport_ws_types.Message = await self.receive_json_from(
timeout=5
)
if message["type"] == "next":
payload = message["payload"]
ret = ExecutionResult(payload.get("data"), None)
if "errors" in payload:
ret.errors = self.process_errors(payload.get("errors") or [])
ret.extensions = payload.get("extensions", None)
yield ret
elif message_type == ErrorMessage.type:
error_payload = ErrorMessage(**response).payload
elif message["type"] == "error":
error_payload = message["payload"]
yield ExecutionResult(
data=None, errors=self.process_errors(error_payload)
)
Expand Down
Loading

0 comments on commit efc736f

Please sign in to comment.