Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8933/8f3b1f44 backport][3.11] Small cleanups to the websocket frame sender #8979

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class WSMsgType(IntEnum):
PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
DEFAULT_LIMIT: Final[int] = 2**16
MASK_LEN: Final[int] = 4


class WSMessage(NamedTuple):
Expand Down Expand Up @@ -625,12 +626,18 @@ async def _send_frame(
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError("Cannot write to closing transport")

# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0

# Only compress larger packets (disabled)
# Does small packet needs to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
# RSV1 (rsv = 0x40) is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
rsv = 0x40

if compress:
# Do not set self._compress if compressing is for this frame
compressobj = self._make_compress_obj(compress)
Expand All @@ -649,40 +656,56 @@ async def _send_frame(
)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
rsv = rsv | 0x40

msg_length = len(message)

use_mask = self.use_mask
if use_mask:
mask_bit = 0x80
else:
mask_bit = 0
mask_bit = 0x80 if use_mask else 0

# Depending on the message length, the header is assembled differently.
# The first byte is reserved for the opcode and the RSV bits.
first_byte = 0x80 | rsv | opcode
if msg_length < 126:
header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
header = PACK_LEN1(first_byte, msg_length | mask_bit)
header_len = 2
elif msg_length < (1 << 16):
header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
header_len = 4
else:
header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
header_len = 10

# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
# If we are using a mask, we need to generate it randomly
# and apply it to the message before sending it. A mask is
# a 32-bit value that is applied to the message using a
# bitwise XOR operation. It is used to prevent certain types
# of attacks on the websocket protocol. The mask is only used
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
_websocket_mask(mask, message)
self._write(header + mask + message)
self._output_size += len(header) + len(mask) + msg_length
self._output_size += header_len + MASK_LEN + msg_length

else:
if msg_length > MSG_SIZE:
self._write(header)
self._write(message)
else:
self._write(header + message)

self._output_size += len(header) + msg_length
self._output_size += header_len + msg_length

# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.

# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()
Expand Down
Loading