Skip to content

Commit

Permalink
Use SSL context for "wss" scheme (#869)
Browse files Browse the repository at this point in the history
* Add failing tests

* Use SSL context for "wss" scheme

* Update CHANGELOG.md

* Update CHANGELOG.md

---------

Co-authored-by: Tom Christie <tom@tomchristie.com>
  • Loading branch information
MtkN1 and tomchristie authored Jan 9, 2024
1 parent f60e99b commit b2de19e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## Unreleased

- Fix trace extension when used with socks proxy. (#849)
- Fix SSL context for connections using the "wss" scheme (#869)

## 1.0.2 (November 10th, 2023)

Expand Down
2 changes: 1 addition & 1 deletion httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
)
trace.return_value = stream

if self._origin.scheme == b"https":
if self._origin.scheme in (b"https", b"wss"):
ssl_context = (
default_ssl_context()
if self._ssl_context is None
Expand Down
2 changes: 1 addition & 1 deletion httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _connect(self, request: Request) -> NetworkStream:
)
trace.return_value = stream

if self._origin.scheme == b"https":
if self._origin.scheme in (b"https", b"wss"):
ssl_context = (
default_ssl_context()
if self._ssl_context is None
Expand Down
22 changes: 22 additions & 0 deletions tests/_async/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,15 +767,37 @@ async def test_http11_upgrade_connection():
b"...",
]
)

called = []

async def trace(name, kwargs):
called.append(name)

async with httpcore.AsyncConnectionPool(
network_backend=network_backend, max_connections=1
) as pool:
async with pool.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
extensions={"trace": trace},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]
content = await network_stream.read(max_bytes=1024)
assert content == b"..."

assert called == [
"connection.connect_tcp.started",
"connection.connect_tcp.complete",
"connection.start_tls.started",
"connection.start_tls.complete",
"http11.send_request_headers.started",
"http11.send_request_headers.complete",
"http11.send_request_body.started",
"http11.send_request_body.complete",
"http11.receive_response_headers.started",
"http11.receive_response_headers.complete",
"http11.response_closed.started",
"http11.response_closed.complete",
]
22 changes: 22 additions & 0 deletions tests/_sync/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,15 +767,37 @@ def test_http11_upgrade_connection():
b"...",
]
)

called = []

def trace(name, kwargs):
called.append(name)

with httpcore.ConnectionPool(
network_backend=network_backend, max_connections=1
) as pool:
with pool.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
extensions={"trace": trace},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]
content = network_stream.read(max_bytes=1024)
assert content == b"..."

assert called == [
"connection.connect_tcp.started",
"connection.connect_tcp.complete",
"connection.start_tls.started",
"connection.start_tls.complete",
"http11.send_request_headers.started",
"http11.send_request_headers.complete",
"http11.send_request_body.started",
"http11.send_request_body.complete",
"http11.receive_response_headers.started",
"http11.receive_response_headers.complete",
"http11.response_closed.started",
"http11.response_closed.complete",
]

0 comments on commit b2de19e

Please sign in to comment.