Skip to content

Commit

Permalink
Updated websocket consumers for newer ASGI spec versions. (#2002)
Browse files Browse the repository at this point in the history
Adds `headers` and `reason` args to accept and close events respectively.
  • Loading branch information
kristjanvalur authored Apr 3, 2024
1 parent 82c26f2 commit de88e03
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 13 deletions.
35 changes: 23 additions & 12 deletions channels/generic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ def websocket_connect(self, message):
def connect(self):
self.accept()

def accept(self, subprotocol=None):
def accept(self, subprotocol=None, headers=None):
"""
Accepts an incoming socket
"""
super().send({"type": "websocket.accept", "subprotocol": subprotocol})
message = {"type": "websocket.accept", "subprotocol": subprotocol}
if headers:
message["headers"] = list(headers)

super().send(message)

def websocket_receive(self, message):
"""
Expand Down Expand Up @@ -79,14 +83,16 @@ def send(self, text_data=None, bytes_data=None, close=False):
if close:
self.close(close)

def close(self, code=None):
def close(self, code=None, reason=None):
"""
Closes the WebSocket from the server end
"""
message = {"type": "websocket.close"}
if code is not None and code is not True:
super().send({"type": "websocket.close", "code": code})
else:
super().send({"type": "websocket.close"})
message["code"] = code
if reason:
message["reason"] = reason
super().send(message)

def websocket_disconnect(self, message):
"""
Expand Down Expand Up @@ -179,11 +185,14 @@ async def websocket_connect(self, message):
async def connect(self):
await self.accept()

async def accept(self, subprotocol=None):
async def accept(self, subprotocol=None, headers=None):
"""
Accepts an incoming socket
"""
await super().send({"type": "websocket.accept", "subprotocol": subprotocol})
message = {"type": "websocket.accept", "subprotocol": subprotocol}
if headers:
message["headers"] = list(headers)
await super().send(message)

async def websocket_receive(self, message):
"""
Expand Down Expand Up @@ -214,14 +223,16 @@ async def send(self, text_data=None, bytes_data=None, close=False):
if close:
await self.close(close)

async def close(self, code=None):
async def close(self, code=None, reason=None):
"""
Closes the WebSocket from the server end
"""
message = {"type": "websocket.close"}
if code is not None and code is not True:
await super().send({"type": "websocket.close", "code": code})
else:
await super().send({"type": "websocket.close"})
message["code"] = code
if reason:
message["reason"] = reason
await super().send(message)

async def websocket_disconnect(self, message):
"""
Expand Down
9 changes: 8 additions & 1 deletion channels/testing/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class WebsocketCommunicator(ApplicationCommunicator):
(uninstantiated) along with the initial connection parameters.
"""

def __init__(self, application, path, headers=None, subprotocols=None):
def __init__(
self, application, path, headers=None, subprotocols=None, spec_version=None
):
if not isinstance(path, str):
raise TypeError("Expected str, got {}".format(type(path)))
parsed = urlparse(path)
Expand All @@ -23,7 +25,10 @@ def __init__(self, application, path, headers=None, subprotocols=None):
"headers": headers or [],
"subprotocols": subprotocols or [],
}
if spec_version:
self.scope["spec_version"] = spec_version
super().__init__(application, self.scope)
self.response_headers = None

async def connect(self, timeout=1):
"""
Expand All @@ -37,6 +42,8 @@ async def connect(self, timeout=1):
if response["type"] == "websocket.close":
return (False, response.get("code", 1000))
else:
assert response["type"] == "websocket.accept"
self.response_headers = response.get("headers", [])
return (True, response.get("subprotocol", None))

async def send_to(self, text_data=None, bytes_data=None):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_generic_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,3 +424,57 @@ async def _my_private_handler(self, _):
ValueError, match=r"Malformed type in message \(leading underscore\)"
):
await communicator.receive_from()


@pytest.mark.parametrize("async_consumer", [False, True])
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_accept_headers(async_consumer):
"""
Tests that JsonWebsocketConsumer is implemented correctly.
"""

class TestConsumer(WebsocketConsumer):
def connect(self):
self.accept(headers=[[b"foo", b"bar"]])

class AsyncTestConsumer(AsyncWebsocketConsumer):
async def connect(self):
await self.accept(headers=[[b"foo", b"bar"]])

app = AsyncTestConsumer() if async_consumer else TestConsumer()

# Open a connection
communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3")
connected, _ = await communicator.connect()
assert connected
assert communicator.response_headers == [[b"foo", b"bar"]]


@pytest.mark.parametrize("async_consumer", [False, True])
@pytest.mark.django_db
@pytest.mark.asyncio
async def test_close_reason(async_consumer):
"""
Tests that JsonWebsocketConsumer is implemented correctly.
"""

class TestConsumer(WebsocketConsumer):
def connect(self):
self.accept()
self.close(code=4007, reason="test reason")

class AsyncTestConsumer(AsyncWebsocketConsumer):
async def connect(self):
await self.accept()
await self.close(code=4007, reason="test reason")

app = AsyncTestConsumer() if async_consumer else TestConsumer()

# Open a connection
communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3")
connected, _ = await communicator.connect()
msg = await communicator.receive_output()
assert msg["type"] == "websocket.close"
assert msg["code"] == 4007
assert msg["reason"] == "test reason"

0 comments on commit de88e03

Please sign in to comment.