Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ConnectionResetError not being raised when the transport is close… #7198

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7180.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol``
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Ilya Gruzinov
Ingmar Steen
Ivan Lakovic
Ivan Larin
J. Nick Koston
Jacob Champion
Jaesung Lee
Jake Davis
Expand Down
9 changes: 6 additions & 3 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._connection_lost = False
self._reading_paused = False

self.transport: Optional[asyncio.Transport] = None

@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None

def pause_writing(self) -> None:
assert not self._paused
self._paused = True
Expand Down Expand Up @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = tr

def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
Expand All @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
waiter.set_exception(exc)

async def _drain_helper(self) -> None:
if self._connection_lost:
if not self.connected:
raise ConnectionResetError("Connection lost")
if not self._paused:
return
Expand Down
10 changes: 4 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self._transport = protocol.transport

self.loop = loop
self.length = None
Expand All @@ -52,7 +51,7 @@ def __init__(

@property
def transport(self) -> Optional[asyncio.Transport]:
return self._transport
return self._protocol.transport

@property
def protocol(self) -> BaseProtocol:
Expand All @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size

if self._transport is None or self._transport.is_closing():
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
self._transport.write(chunk)
transport.write(chunk)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
Expand Down Expand Up @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self.drain()

self._eof = True
self._transport = None

async def drain(self) -> None:
"""Flush the write buffer.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ async def test_connection_lost_not_paused() -> None:
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_connection_lost_paused_without_waiter() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.pause_writing()
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_drain_lost() -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,17 @@ async def test_eof_received(loop) -> None:
assert proto._read_timeout_handle is not None
proto.eof_received()
assert proto._read_timeout_handle is None


async def test_connection_lost_sets_transport_to_none(loop, mocker) -> None:
"""Ensure that the transport is set to None when the connection is lost.

This ensures the writer knows that the connection is closed.
"""
proto = ResponseHandler(loop=loop)
proto.connection_made(mocker.Mock())
assert proto.transport is not None

proto.connection_lost(OSError())

assert proto.transport is None
15 changes: 15 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,21 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None:
await msg.write(b"After closing")


async def test_write_to_closed_transport(protocol, transport, loop) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.

The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
"""
msg = http.StreamWriter(protocol, loop)

await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
await msg.write(b"After transport closed")


async def test_drain(protocol, transport, loop) -> None:
msg = http.StreamWriter(protocol, loop)
await msg.drain()
Expand Down