diff --git a/channels/generic/websocket.py b/channels/generic/websocket.py index 6c8ca576..9ce2657b 100644 --- a/channels/generic/websocket.py +++ b/channels/generic/websocket.py @@ -44,11 +44,15 @@ def websocket_connect(self, message): def connect(self): self.accept() - def accept(self, subprotocol=None): + def accept(self, subprotocol=None, headers=None): """ Accepts an incoming socket """ - super().send({"type": "websocket.accept", "subprotocol": subprotocol}) + message = {"type": "websocket.accept", "subprotocol": subprotocol} + if headers: + message["headers"] = list(headers) + + super().send(message) def websocket_receive(self, message): """ @@ -79,14 +83,16 @@ def send(self, text_data=None, bytes_data=None, close=False): if close: self.close(close) - def close(self, code=None): + def close(self, code=None, reason=None): """ Closes the WebSocket from the server end """ + message = {"type": "websocket.close"} if code is not None and code is not True: - super().send({"type": "websocket.close", "code": code}) - else: - super().send({"type": "websocket.close"}) + message["code"] = code + if reason: + message["reason"] = reason + super().send(message) def websocket_disconnect(self, message): """ @@ -179,11 +185,14 @@ async def websocket_connect(self, message): async def connect(self): await self.accept() - async def accept(self, subprotocol=None): + async def accept(self, subprotocol=None, headers=None): """ Accepts an incoming socket """ - await super().send({"type": "websocket.accept", "subprotocol": subprotocol}) + message = {"type": "websocket.accept", "subprotocol": subprotocol} + if headers: + message["headers"] = list(headers) + await super().send(message) async def websocket_receive(self, message): """ @@ -214,14 +223,16 @@ async def send(self, text_data=None, bytes_data=None, close=False): if close: await self.close(close) - async def close(self, code=None): + async def close(self, code=None, reason=None): """ Closes the WebSocket from the server end """ + message = {"type": "websocket.close"} if code is not None and code is not True: - await super().send({"type": "websocket.close", "code": code}) - else: - await super().send({"type": "websocket.close"}) + message["code"] = code + if reason: + message["reason"] = reason + await super().send(message) async def websocket_disconnect(self, message): """ diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index e1763214..dd48686d 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -12,7 +12,9 @@ class WebsocketCommunicator(ApplicationCommunicator): (uninstantiated) along with the initial connection parameters. """ - def __init__(self, application, path, headers=None, subprotocols=None): + def __init__( + self, application, path, headers=None, subprotocols=None, spec_version=None + ): if not isinstance(path, str): raise TypeError("Expected str, got {}".format(type(path))) parsed = urlparse(path) @@ -23,7 +25,10 @@ def __init__(self, application, path, headers=None, subprotocols=None): "headers": headers or [], "subprotocols": subprotocols or [], } + if spec_version: + self.scope["spec_version"] = spec_version super().__init__(application, self.scope) + self.response_headers = None async def connect(self, timeout=1): """ @@ -37,6 +42,8 @@ async def connect(self, timeout=1): if response["type"] == "websocket.close": return (False, response.get("code", 1000)) else: + assert response["type"] == "websocket.accept" + self.response_headers = response.get("headers", []) return (True, response.get("subprotocol", None)) async def send_to(self, text_data=None, bytes_data=None): diff --git a/tests/test_generic_websocket.py b/tests/test_generic_websocket.py index 10a761a9..73cdb486 100644 --- a/tests/test_generic_websocket.py +++ b/tests/test_generic_websocket.py @@ -424,3 +424,57 @@ async def _my_private_handler(self, _): ValueError, match=r"Malformed type in message \(leading underscore\)" ): await communicator.receive_from() + + +@pytest.mark.parametrize("async_consumer", [False, True]) +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_accept_headers(async_consumer): + """ + Tests that JsonWebsocketConsumer is implemented correctly. + """ + + class TestConsumer(WebsocketConsumer): + def connect(self): + self.accept(headers=[[b"foo", b"bar"]]) + + class AsyncTestConsumer(AsyncWebsocketConsumer): + async def connect(self): + await self.accept(headers=[[b"foo", b"bar"]]) + + app = AsyncTestConsumer() if async_consumer else TestConsumer() + + # Open a connection + communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3") + connected, _ = await communicator.connect() + assert connected + assert communicator.response_headers == [[b"foo", b"bar"]] + + +@pytest.mark.parametrize("async_consumer", [False, True]) +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_close_reason(async_consumer): + """ + Tests that JsonWebsocketConsumer is implemented correctly. + """ + + class TestConsumer(WebsocketConsumer): + def connect(self): + self.accept() + self.close(code=4007, reason="test reason") + + class AsyncTestConsumer(AsyncWebsocketConsumer): + async def connect(self): + await self.accept() + await self.close(code=4007, reason="test reason") + + app = AsyncTestConsumer() if async_consumer else TestConsumer() + + # Open a connection + communicator = WebsocketCommunicator(app, "/testws/", spec_version="2.3") + connected, _ = await communicator.connect() + msg = await communicator.receive_output() + assert msg["type"] == "websocket.close" + assert msg["code"] == 4007 + assert msg["reason"] == "test reason"