Skip to content

Commit

Permalink
Fix WebSocket reader flow control calculations (#9685)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 5241897)
  • Loading branch information
bdraco committed Nov 11, 2024
1 parent bd3a3be commit 7027714
Show file tree
Hide file tree
Showing 16 changed files with 235 additions and 148 deletions.
1 change: 1 addition & 0 deletions CHANGES/9685.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``FlowControlDataQueue`` has been replaced with the ``WebSocketDataQueue`` -- by :user:`bdraco`.
2 changes: 0 additions & 2 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
DataQueue as DataQueue,
EofStream as EofStream,
FlowControlDataQueue as FlowControlDataQueue,
StreamReader as StreamReader,
)
from .tracing import (
Expand Down Expand Up @@ -216,7 +215,6 @@
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"FlowControlDataQueue",
"StreamReader",
# tracing
"TraceConfig",
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/_websocket/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@
from ..helpers import NO_EXTENSIONS

if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import WebSocketReader as WebSocketReaderPython
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)

WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
else:
try:
from .reader_c import ( # type: ignore[import-not-found]
WebSocketDataQueue as WebSocketDataQueueCython,
WebSocketReader as WebSocketReaderCython,
)

WebSocketReader = WebSocketReaderCython
WebSocketDataQueue = WebSocketDataQueueCython
except ImportError: # pragma: no cover
from .reader_py import WebSocketReader as WebSocketReaderPython
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)

WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
23 changes: 21 additions & 2 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,30 @@ cdef set MESSAGE_TYPES_WITH_CONTENT
cdef tuple EMPTY_FRAME
cdef tuple EMPTY_FRAME_ERROR

cdef class WebSocketDataQueue:

cdef unsigned int _size
cdef public object _protocol
cdef unsigned int _limit
cdef object _loop
cdef bint _eof
cdef object _waiter
cdef object _exception
cdef public object _buffer
cdef object _get_buffer
cdef object _put_buffer

cdef void _release_waiter(self)

@cython.locals(size="unsigned int")
cpdef void feed_data(self, object data)

@cython.locals(size="unsigned int")
cdef _read_from_buffer(self)

cdef class WebSocketReader:

cdef object queue
cdef object _queue_feed_data
cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size

cdef Exception _exc
Expand Down
105 changes: 91 additions & 14 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Reader for WebSocket protocol versions 13 and 8."""

from typing import Final, List, Optional, Set, Tuple, Union
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, List, Optional, Set, Tuple, Type, Union

from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import set_exception
from ..streams import FlowControlDataQueue
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
Expand Down Expand Up @@ -40,15 +44,87 @@
TUPLE_NEW = tuple.__new__


class WebSocketReader:
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""

def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[Type[BaseException], BaseException, None] = None
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append

def exception(self) -> Optional[Union[Type[BaseException], BaseException]]:
return self._exception

def set_exception(
self,
queue: FlowControlDataQueue[WSMessage],
max_msg_size: int,
compress: bool = True,
exc: Union[Type[BaseException], BaseException],
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)

def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)

def feed_eof(self) -> None:
self._eof = True
self._release_waiter()

def feed_data(self, data: "WSMessage", size: int) -> None:
size = data.size
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()

async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()

def _read_from_buffer(self) -> WSMessage:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream


class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._queue_feed_data = queue.feed_data
self._max_msg_size = max_msg_size

self._exc: Optional[Exception] = None
Expand Down Expand Up @@ -187,17 +263,18 @@ def _feed_data(self, data: bytes) -> None:
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self._queue_feed_data(
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
self._queue_feed_data(
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
payload_len = len(payload)
if payload_len >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
Expand All @@ -221,14 +298,14 @@ def _feed_data(self, data: bytes) -> None:
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))

self._queue_feed_data(msg, 0)
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self._queue_feed_data(msg, len(payload))
self.queue.feed_data(msg, len(payload))

elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self._queue_feed_data(msg, len(payload))
self.queue.feed_data(msg, len(payload))

else:
raise WebSocketError(
Expand Down
8 changes: 3 additions & 5 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from yarl import URL

from . import hdrs, http, payload
from ._websocket.reader import WebSocketDataQueue
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
Expand Down Expand Up @@ -100,8 +101,7 @@
strip_auth_from_url,
)
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
from .tracing import Trace, TraceConfig
from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL

Expand Down Expand Up @@ -1098,9 +1098,7 @@ async def _ws_connect(

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
conn_proto, 2**16, loop=self._loop
)
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
conn_proto,
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def set_exception(
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: FlowControlDataQueue
# payload: WebSocketDataQueue
# but they are not generi enough
# Need an ABC for both types
self._payload = payload
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import attr

from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
Expand All @@ -19,7 +20,7 @@
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream, FlowControlDataQueue
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
Expand All @@ -45,7 +46,7 @@ class ClientWSTimeout:
class ClientWebSocketResponse:
def __init__(
self,
reader: "FlowControlDataQueue[WSMessage]",
reader: WebSocketDataQueue,
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
Expand Down
Loading

0 comments on commit 7027714

Please sign in to comment.