Skip to content

Commit

Permalink
Support extra_headers for WS accept message
Browse files Browse the repository at this point in the history
  • Loading branch information
matiuszka committed Dec 15, 2021
1 parent 65ec8d1 commit 799bce2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
20 changes: 20 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ async def open_connection(url):
assert is_open


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_extra_headers(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send(
{"type": "websocket.accept", "headers": [(b"extra", b"header")]}
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
extra_headers = await open_connection("ws://127.0.0.1:8000")
assert extra_headers.get("extra") == "header"


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down
7 changes: 7 additions & 0 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
ping_timeout=self.config.ws_ping_timeout,
extensions=[ServerPerMessageDeflateFactory()],
logger=logging.getLogger("uvicorn.error"),
extra_headers=[],
)

def connection_made(self, transport):
Expand Down Expand Up @@ -231,6 +232,12 @@ async def asgi_send(self, message):
)
self.initial_response = None
self.accepted_subprotocol = message.get("subprotocol")
self.extra_headers.extend(
# ASGI spec requires bytes
# But for compability we need to convert it to strings
(name.decode(), value.decode())
for name, value in message.get("headers")
)
self.handshake_started_event.set()

elif message_type == "websocket.close":
Expand Down
5 changes: 4 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ async def send(self, message):
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = message.get("headers", [])
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol, extensions=[PerMessageDeflate()]
subprotocol=subprotocol,
extensions=[PerMessageDeflate()],
extra_headers=extra_headers,
)
)
self.transport.write(output)
Expand Down

0 comments on commit 799bce2

Please sign in to comment.