Skip to content

Commit

Permalink
[PR #8699/11f0e7f backport][3.11] Reduce code indent in ResponseHandl…
Browse files Browse the repository at this point in the history
…er.data_received (#10056)
  • Loading branch information
bdraco authored Nov 27, 2024
1 parent 1a6fafe commit 7e628f4
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 54 deletions.
105 changes: 52 additions & 53 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def data_received(self, data: bytes) -> None:
if not data:
return

# custom payload parser
# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
eof, tail = self._payload_parser.feed_data(data)
if eof:
Expand All @@ -252,57 +252,56 @@ def data_received(self, data: bytes) -> None:
if tail:
self.data_received(tail)
return
else:
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data

if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
return

# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return

self._upgraded = upgraded

payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._payload = payload

if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return

if tail:
if upgraded:
self.data_received(tail)
else:
self._tail = tail
self._upgraded = upgraded

payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._payload = payload

if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)

if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()

if upgraded and tail:
self.data_received(tail)
84 changes: 83 additions & 1 deletion tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,89 @@ async def test_uncompleted_message(loop) -> None:
assert dict(exc.message.headers) == {"Location": "http://python.org/"}


async def test_client_protocol_readuntil_eof(loop) -> None:
async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
transport = mock.Mock()
proto.connection_made(transport)
proto.set_response_params(read_until_eof=True)
proto.close()
assert transport.close.called
transport.close.reset_mock()
proto.data_received(b"HTTP\r\n\r\n")
assert proto.should_close
assert not transport.close.called
assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_multiple_responses_one_byte_at_a_time(
loop: asyncio.AbstractEventLoop,
) -> None:
proto = ResponseHandler(loop=loop)
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
proto.set_response_params(read_until_eof=True)

for _ in range(2):
messages = (
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab"
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd"
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nef"
)
for i in range(len(messages)):
proto.data_received(messages[i : i + 1])

expected = [b"ab", b"cd", b"ef"]
for payload in expected:
response = ClientResponse(
"get",
URL("http://def-cl-resp.org"),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
request_info=mock.Mock(),
traces=[],
loop=loop,
session=mock.Mock(),
)
await response.start(conn)
await response.read() == payload


async def test_unexpected_exception_during_data_received(
loop: asyncio.AbstractEventLoop,
) -> None:
proto = ResponseHandler(loop=loop)

class PatchableHttpResponseParser(http.HttpResponseParser):
"""Subclass of HttpResponseParser to make it patchable."""

with mock.patch(
"aiohttp.client_proto.HttpResponseParser", PatchableHttpResponseParser
):
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
proto.set_response_params(read_until_eof=True)
proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab")
response = ClientResponse(
"get",
URL("http://def-cl-resp.org"),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
request_info=mock.Mock(),
traces=[],
loop=loop,
session=mock.Mock(),
)
await response.start(conn)
await response.read() == b"ab"
with mock.patch.object(proto._parser, "feed_data", side_effect=ValueError):
proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd")

assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
transport = mock.Mock()
proto.connection_made(transport)
Expand Down

0 comments on commit 7e628f4

Please sign in to comment.