From c1c980e408b5b4b5d7a85da6715154f4ff1e0edb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Nov 2021 12:34:53 +0000 Subject: [PATCH] Network errors on HTTP/2 should propogate across all streams --- httpcore/_async/http2.py | 49 ++++++++++++++++++++++++++++++++++++---- httpcore/_sync/http2.py | 49 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 10 deletions(-) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 7f4d18d7..8f05eb48 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -52,7 +52,10 @@ def __init__( self._write_lock = AsyncLock() self._sent_connection_init = False self._used_all_stream_ids = False + self._connection_error = False self._events: typing.Dict[int, h2.events.Event] = {} + self._read_exception: typing.Optional[Exception] = None + self._write_exception: typing.Optional[Exception] = None async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -282,9 +285,26 @@ async def _read_incoming_data( timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) - data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) - if data == b"": - raise RemoteProtocolError("Server disconnected") + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._read_exception = exc + self._connection_error = True + raise exc + events = self._h2_state.receive_data(data) return events @@ -295,7 +315,24 @@ async def _write_outgoing_data(self, request: Request) -> None: async with self._write_lock: data_to_send = self._h2_state.data_to_send() - await self._network_stream.write(data_to_send, timeout) + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + await self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc # Flow control... @@ -324,7 +361,9 @@ def can_handle_request(self, origin: Origin) -> bool: def is_available(self) -> bool: return ( - self._state != HTTPConnectionState.CLOSED and not self._used_all_stream_ids + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids ) def has_expired(self) -> bool: diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 33125a1c..9976cfe3 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -52,7 +52,10 @@ def __init__( self._write_lock = Lock() self._sent_connection_init = False self._used_all_stream_ids = False + self._connection_error = False self._events: typing.Dict[int, h2.events.Event] = {} + self._read_exception: typing.Optional[Exception] = None + self._write_exception: typing.Optional[Exception] = None def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -282,9 +285,26 @@ def _read_incoming_data( timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) - data = self._network_stream.read(self.READ_NUM_BYTES, timeout) - if data == b"": - raise RemoteProtocolError("Server disconnected") + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._read_exception = exc + self._connection_error = True + raise exc + events = self._h2_state.receive_data(data) return events @@ -295,7 +315,24 @@ def _write_outgoing_data(self, request: Request) -> None: with self._write_lock: data_to_send = self._h2_state.data_to_send() - self._network_stream.write(data_to_send, timeout) + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc # Flow control... @@ -324,7 +361,9 @@ def can_handle_request(self, origin: Origin) -> bool: def is_available(self) -> bool: return ( - self._state != HTTPConnectionState.CLOSED and not self._used_all_stream_ids + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids ) def has_expired(self) -> bool: