From b2de19e594b0ef89521afaa5e6ea9f33cabf3f67 Mon Sep 17 00:00:00 2001 From: MtkN1 <51289448+MtkN1@users.noreply.github.com> Date: Wed, 10 Jan 2024 00:25:44 +0900 Subject: [PATCH] Use SSL context for "wss" scheme (#869) * Add failing tests * Use SSL context for "wss" scheme * Update CHANGELOG.md * Update CHANGELOG.md --------- Co-authored-by: Tom Christie --- CHANGELOG.md | 1 + httpcore/_async/connection.py | 2 +- httpcore/_sync/connection.py | 2 +- tests/_async/test_connection_pool.py | 22 ++++++++++++++++++++++ tests/_sync/test_connection_pool.py | 22 ++++++++++++++++++++++ 5 files changed, 47 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d42bc585..061358f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## Unreleased - Fix trace extension when used with socks proxy. (#849) +- Fix SSL context for connections using the "wss" scheme (#869) ## 1.0.2 (November 10th, 2023) diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 45ee22a6..3aeb8ed9 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -137,7 +137,7 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: ) trace.return_value = stream - if self._origin.scheme == b"https": + if self._origin.scheme in (b"https", b"wss"): ssl_context = ( default_ssl_context() if self._ssl_context is None diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 81e4172a..f6b99f1b 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -137,7 +137,7 @@ def _connect(self, request: Request) -> NetworkStream: ) trace.return_value = stream - if self._origin.scheme == b"https": + if self._origin.scheme in (b"https", b"wss"): ssl_context = ( default_ssl_context() if self._ssl_context is None diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index 2392ca17..61ee1e54 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -767,6 +767,12 @@ async def test_http11_upgrade_connection(): b"...", ] ) + + called = [] + + async def trace(name, kwargs): + called.append(name) + async with httpcore.AsyncConnectionPool( network_backend=network_backend, max_connections=1 ) as pool: @@ -774,8 +780,24 @@ async def test_http11_upgrade_connection(): "GET", "wss://example.com/", headers={"Connection": "upgrade", "Upgrade": "custom"}, + extensions={"trace": trace}, ) as response: assert response.status == 101 network_stream = response.extensions["network_stream"] content = await network_stream.read(max_bytes=1024) assert content == b"..." + + assert called == [ + "connection.connect_tcp.started", + "connection.connect_tcp.complete", + "connection.start_tls.started", + "connection.start_tls.complete", + "http11.send_request_headers.started", + "http11.send_request_headers.complete", + "http11.send_request_body.started", + "http11.send_request_body.complete", + "http11.receive_response_headers.started", + "http11.receive_response_headers.complete", + "http11.response_closed.started", + "http11.response_closed.complete", + ] diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index 287c2bcc..c9621c7b 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -767,6 +767,12 @@ def test_http11_upgrade_connection(): b"...", ] ) + + called = [] + + def trace(name, kwargs): + called.append(name) + with httpcore.ConnectionPool( network_backend=network_backend, max_connections=1 ) as pool: @@ -774,8 +780,24 @@ def test_http11_upgrade_connection(): "GET", "wss://example.com/", headers={"Connection": "upgrade", "Upgrade": "custom"}, + extensions={"trace": trace}, ) as response: assert response.status == 101 network_stream = response.extensions["network_stream"] content = network_stream.read(max_bytes=1024) assert content == b"..." + + assert called == [ + "connection.connect_tcp.started", + "connection.connect_tcp.complete", + "connection.start_tls.started", + "connection.start_tls.complete", + "http11.send_request_headers.started", + "http11.send_request_headers.complete", + "http11.send_request_body.started", + "http11.send_request_body.complete", + "http11.receive_response_headers.started", + "http11.receive_response_headers.complete", + "http11.response_closed.started", + "http11.response_closed.complete", + ]