Skip to content

Commit

Permalink
Check disconnections on connection reacquiry (#145)
Browse files Browse the repository at this point in the history
* Detect EOF signaling remote server closed connection

Raise ConnectionClosedByRemote and handle on `send`

* Fix linting

* Use existing NotConnected exception

* Add `Reader.is_connection_dropped` method

* Check connection before sending h11 events as well

* Add test covering connection lost before reading response content

* Check for connection closed on acquiring it from the pool

* Clean up ConnectionPool logic around reaquiry of connections
  • Loading branch information
tomchristie authored Jul 25, 2019
1 parent 8db36ed commit ec365c0
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 51 deletions.
3 changes: 3 additions & 0 deletions httpx/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ async def read(

return data

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()


class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
Expand Down
7 changes: 7 additions & 0 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,10 @@ def is_closed(self) -> bool:
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed

def is_connection_dropped(self) -> bool:
if self.h2_connection is not None:
return self.h2_connection.is_connection_dropped()
else:
assert self.h11_connection is not None
return self.h11_connection.is_connection_dropped()
41 changes: 15 additions & 26 deletions httpx/dispatch/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
TimeoutTypes,
VerifyTypes,
)
from ..exceptions import NotConnected
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
Expand Down Expand Up @@ -108,35 +107,25 @@ async def send(
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
allow_connection_reuse = True
connection = None
while connection is None:
connection = await self.acquire_connection(
origin=request.url.origin, allow_connection_reuse=allow_connection_reuse
connection = await self.acquire_connection(origin=request.url.origin)
try:
response = await connection.send(
request, verify=verify, cert=cert, timeout=timeout
)
try:
response = await connection.send(
request, verify=verify, cert=cert, timeout=timeout
)
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()
if isinstance(exc, NotConnected) and allow_connection_reuse:
connection = None
allow_connection_reuse = False
else:
raise exc
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()
raise exc

return response

async def acquire_connection(
self, origin: Origin, allow_connection_reuse: bool = True
) -> HTTPConnection:
connection = None
if allow_connection_reuse:
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)

if connection is not None and connection.is_connection_dropped():
connection = None

if connection is None:
await self.max_connections.acquire()
Expand Down
11 changes: 4 additions & 7 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse

Expand Down Expand Up @@ -46,12 +45,7 @@ async def send(
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)

try:
await self._send_request(request, timeout)
except ConnectionResetError: # pragma: nocover
# We're currently testing this case in HTTP/2.
# Really we should test it here too, but this'll do in the meantime.
raise NotConnected() from None
await self._send_request(request, timeout)

task, args = self._send_request_data, [request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
Expand Down Expand Up @@ -188,3 +182,6 @@ async def response_closed(self) -> None:
@property
def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
9 changes: 4 additions & 5 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse

Expand Down Expand Up @@ -39,10 +38,7 @@ async def send(
if not self.initialized:
self.initiate_connection()

try:
stream_id = await self.send_headers(request, timeout)
except ConnectionResetError:
raise NotConnected() from None
stream_id = await self.send_headers(request, timeout)

self.events[stream_id] = []
self.timeout_flags[stream_id] = TimeoutFlag()
Expand Down Expand Up @@ -176,3 +172,6 @@ async def response_closed(self, stream_id: int) -> None:
@property
def is_closed(self) -> bool:
return False

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
7 changes: 0 additions & 7 deletions httpx/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ class PoolTimeout(Timeout):
# HTTP exceptions...


class NotConnected(Exception):
"""
A connection was lost at the point of starting a request,
prior to any writes succeeding.
"""


class HttpError(Exception):
"""
An HTTP error occurred.
Expand Down
3 changes: 3 additions & 0 deletions httpx/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ async def read(
) -> bytes:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseWriter:
"""
Expand Down
36 changes: 36 additions & 0 deletions tests/dispatch/test_connection_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,39 @@ async def test_premature_response_close(server):
await response.close()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0


@pytest.mark.asyncio
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection should be reestablished.
"""
async with httpx.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

await server.shutdown() # shutdown the server to close the keep-alive connection
await server.startup()

response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1


@pytest.mark.asyncio
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection should be reestablished.
"""
async with httpx.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

await server.shutdown() # shutdown the server to close the keep-alive connection
await server.startup()

response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
2 changes: 1 addition & 1 deletion tests/dispatch/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_http2_reconnect():

with Client(backend=backend) as client:
response_1 = client.get("http://example.org/1")
backend.server.raise_disconnect = True
backend.server.close_connection = True
response_2 = client.get("http://example.org/2")

assert response_1.status_code == 200
Expand Down
8 changes: 4 additions & 4 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, app):
self.app = app
self.buffer = b""
self.requests = {}
self.raise_disconnect = False
self.close_connection = False

# BaseReader interface

Expand All @@ -55,9 +55,6 @@ async def read(self, n, timeout, flag=None) -> bytes:
# BaseWriter interface

def write_no_block(self, data: bytes) -> None:
if self.raise_disconnect:
self.raise_disconnect = False
raise ConnectionResetError()
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
for event in events:
Expand All @@ -74,6 +71,9 @@ async def write(self, data: bytes, timeout) -> None:
async def close(self) -> None:
pass

def is_connection_dropped(self) -> bool:
return self.close_connection

# Server implementation

def request_received(self, headers, stream_id):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
CertTypes,
Client,
Dispatcher,
multipart,
Request,
Response,
TimeoutTypes,
VerifyTypes,
multipart,
)


Expand Down

0 comments on commit ec365c0

Please sign in to comment.