Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typing information #277

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ source =

[flake8]
max-complexity = 10
max-line-length = 120
exclude =
hpack/huffman_constants.py
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
author_email='cory@lukasa.co.uk',
url='https://github.com/python-hyper/hpack',
packages=find_packages(where="src"),
package_data={'hpack': []},
package_data={'hpack': ['py.typed']},
package_dir={'': 'src'},
python_requires='>=3.9.0',
license='MIT License',
Expand Down
1 change: 0 additions & 1 deletion src/hpack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack
~~~~~
Expand Down
1 change: 0 additions & 1 deletion src/hpack/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hyper/http20/exceptions
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
98 changes: 48 additions & 50 deletions src/hpack/hpack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
hpack/hpack
~~~~~~~~~~~

Implements the HPACK header compression algorithm as detailed by the IETF.
"""
import logging
from typing import Any, Generator, Union

from .table import HeaderTable, table_entry_size
from .exceptions import (
Expand All @@ -16,7 +16,7 @@
REQUEST_CODES, REQUEST_CODES_LENGTH
)
from .huffman_table import decode_huffman
from .struct import HeaderTuple, NeverIndexedHeaderTuple
from .struct import HeaderTuple, NeverIndexedHeaderTuple, Headers

log = logging.getLogger(__name__)

Expand All @@ -29,31 +29,25 @@
# as prefix numbers are not zero indexed.
_PREFIX_BIT_MAX_NUMBERS = [(2 ** i) - 1 for i in range(9)]

try: # pragma: no cover
basestring = basestring
except NameError: # pragma: no cover
basestring = (str, bytes)


# We default the maximum header list we're willing to accept to 64kB. That's a
# lot of headers, but if applications want to raise it they can do.
DEFAULT_MAX_HEADER_LIST_SIZE = 2 ** 16


def _unicode_if_needed(header, raw):
def _unicode_if_needed(header: HeaderTuple, raw: bool) -> HeaderTuple:
"""
Provides a header as a unicode string if raw is False, otherwise returns
it as a bytestring.
"""
name = bytes(header[0])
value = bytes(header[1])
name = bytes(header[0]) # type: ignore
value = bytes(header[1]) # type: ignore
if not raw:
name = name.decode('utf-8')
value = value.decode('utf-8')
return header.__class__(name, value)
return header.__class__(name.decode('utf-8'), value.decode('utf-8'))
else:
return header.__class__(name, value)


def encode_integer(integer, prefix_bits):
def encode_integer(integer: int, prefix_bits: int) -> bytearray:
"""
This encodes an integer according to the wacky integer encoding rules
defined in the HPACK spec.
Expand Down Expand Up @@ -87,7 +81,7 @@ def encode_integer(integer, prefix_bits):
return bytearray(elements)


def decode_integer(data, prefix_bits):
def decode_integer(data: bytes, prefix_bits: int) -> tuple[int, int]:
"""
This decodes an integer according to the wacky integer encoding rules
defined in the HPACK spec. Returns a tuple of the decoded integer and the
Expand Down Expand Up @@ -128,7 +122,8 @@ def decode_integer(data, prefix_bits):
return number, index


def _dict_to_iterable(header_dict):
def _dict_to_iterable(header_dict: Union[dict[bytes, bytes], dict[str, str]]) \
-> Generator[Union[tuple[bytes, bytes], tuple[str, str]], None, None]:
"""
This converts a dictionary to an iterable of two-tuples. This is a
HPACK-specific function because it pulls "special-headers" out first and
Expand All @@ -140,19 +135,19 @@ def _dict_to_iterable(header_dict):
key=lambda k: not _to_bytes(k).startswith(b':')
)
for key in keys:
yield key, header_dict[key]
yield key, header_dict[key] # type: ignore


def _to_bytes(value):
def _to_bytes(value: Union[bytes, str, Any]) -> bytes:
"""
Convert anything to bytes through a UTF-8 encoded string
"""
t = type(value)
if t is bytes:
return value
return value # type: ignore
if t is not str:
value = str(value)
return value.encode("utf-8")
return value.encode("utf-8") # type: ignore


class Encoder:
Expand All @@ -161,27 +156,29 @@ class Encoder:
HTTP/2 header blocks.
"""

def __init__(self):
def __init__(self) -> None:
self.header_table = HeaderTable()
self.huffman_coder = HuffmanEncoder(
REQUEST_CODES, REQUEST_CODES_LENGTH
)
self.table_size_changes = []
self.table_size_changes: list[int] = []

@property
def header_table_size(self):
def header_table_size(self) -> int:
"""
Controls the size of the HPACK header table.
"""
return self.header_table.maxsize

@header_table_size.setter
def header_table_size(self, value):
def header_table_size(self, value: int) -> None:
self.header_table.maxsize = value
if self.header_table.resized:
self.table_size_changes.append(value)

def encode(self, headers, huffman=True):
def encode(self,
headers: Headers,
huffman: bool = True) -> bytes:
"""
Takes a set of headers and encodes them into a HPACK-encoded header
block.
Expand Down Expand Up @@ -256,13 +253,13 @@ def encode(self, headers, huffman=True):
header = (_to_bytes(header[0]), _to_bytes(header[1]))
header_block.append(self.add(header, sensitive, huffman))

header_block = b''.join(header_block)
encoded = b''.join(header_block)

log.debug("Encoded header block to %s", header_block)
log.debug("Encoded header block to %s", encoded)

return header_block
return encoded

def add(self, to_add, sensitive, huffman=False):
def add(self, to_add: tuple[bytes, bytes], sensitive: bool, huffman: bool = False) -> bytes:
"""
This function takes a header key-value tuple and serializes it.
"""
Expand Down Expand Up @@ -311,15 +308,15 @@ def add(self, to_add, sensitive, huffman=False):

return encoded

def _encode_indexed(self, index):
def _encode_indexed(self, index: int) -> bytes:
"""
Encodes a header using the indexed representation.
"""
field = encode_integer(index, 7)
field[0] |= 0x80 # we set the top bit
return bytes(field)

def _encode_literal(self, name, value, indexbit, huffman=False):
def _encode_literal(self, name: bytes, value: bytes, indexbit: bytes, huffman: bool = False) -> bytes:
"""
Encodes a header with a literal name and literal value. If ``indexing``
is True, the header will be added to the header table: otherwise it
Expand All @@ -340,7 +337,7 @@ def _encode_literal(self, name, value, indexbit, huffman=False):
[indexbit, bytes(name_len), name, bytes(value_len), value]
)

def _encode_indexed_literal(self, index, value, indexbit, huffman=False):
def _encode_indexed_literal(self, index: int, value: bytes, indexbit: bytes, huffman: bool = False) -> bytes:
"""
Encodes a header with an indexed name and a literal value and performs
incremental indexing.
Expand All @@ -362,16 +359,16 @@ def _encode_indexed_literal(self, index, value, indexbit, huffman=False):

return b''.join([bytes(prefix), bytes(value_len), value])

def _encode_table_size_change(self):
def _encode_table_size_change(self) -> bytes:
"""
Produces the encoded form of all header table size change context
updates.
"""
block = b''
for size_bytes in self.table_size_changes:
size_bytes = encode_integer(size_bytes, 5)
size_bytes[0] |= 0x20
block += bytes(size_bytes)
b = encode_integer(size_bytes, 5)
b[0] |= 0x20
block += bytes(b)
self.table_size_changes = []
return block

Expand All @@ -397,7 +394,7 @@ class Decoder:
Defaults to 64kB.
:type max_header_list_size: ``int``
"""
def __init__(self, max_header_list_size=DEFAULT_MAX_HEADER_LIST_SIZE):
def __init__(self, max_header_list_size: int = DEFAULT_MAX_HEADER_LIST_SIZE) -> None:
self.header_table = HeaderTable()

#: The maximum decompressed size we will allow for any single header
Expand Down Expand Up @@ -426,17 +423,17 @@ def __init__(self, max_header_list_size=DEFAULT_MAX_HEADER_LIST_SIZE):
self.max_allowed_table_size = self.header_table.maxsize

@property
def header_table_size(self):
def header_table_size(self) -> int:
"""
Controls the size of the HPACK header table.
"""
return self.header_table.maxsize

@header_table_size.setter
def header_table_size(self, value):
def header_table_size(self, value: int) -> None:
self.header_table.maxsize = value

def decode(self, data, raw=False):
def decode(self, data: bytes, raw: bool = False) -> Headers:
"""
Takes an HPACK-encoded header block and decodes it into a header set.

Expand All @@ -454,7 +451,7 @@ def decode(self, data, raw=False):
log.debug("Decoding %s", data)

data_mem = memoryview(data)
headers = []
headers: list[HeaderTuple] = []
data_len = len(data)
inflated_size = 0
current_index = 0
Expand Down Expand Up @@ -501,7 +498,7 @@ def decode(self, data, raw=False):

if header:
headers.append(header)
inflated_size += table_entry_size(*header)
inflated_size += table_entry_size(header[0], header[1])

if inflated_size > self.max_header_list_size:
raise OversizedHeaderListError(
Expand All @@ -521,7 +518,7 @@ def decode(self, data, raw=False):
except UnicodeDecodeError:
raise HPACKDecodingError("Unable to decode headers as UTF-8.")

def _assert_valid_table_size(self):
def _assert_valid_table_size(self) -> None:
"""
Check that the table size set by the encoder is lower than the maximum
we expect to have.
Expand All @@ -531,7 +528,7 @@ def _assert_valid_table_size(self):
"Encoder did not shrink table size to within the max"
)

def _update_encoding_context(self, data):
def _update_encoding_context(self, data: bytes) -> int:
"""
Handles a byte that updates the encoding context.
"""
Expand All @@ -544,7 +541,7 @@ def _update_encoding_context(self, data):
self.header_table_size = new_size
return consumed

def _decode_indexed(self, data):
def _decode_indexed(self, data: bytes) -> tuple[HeaderTuple, int]:
"""
Decodes a header represented using the indexed representation.
"""
Expand All @@ -553,13 +550,13 @@ def _decode_indexed(self, data):
log.debug("Decoded %s, consumed %d", header, consumed)
return header, consumed

def _decode_literal_no_index(self, data):
def _decode_literal_no_index(self, data: bytes) -> tuple[HeaderTuple, int]:
return self._decode_literal(data, False)

def _decode_literal_index(self, data):
def _decode_literal_index(self, data: bytes) -> tuple[HeaderTuple, int]:
return self._decode_literal(data, True)

def _decode_literal(self, data, should_index):
def _decode_literal(self, data: bytes, should_index: bool) -> tuple[HeaderTuple, int]:
"""
Decodes a header represented with a literal.
"""
Expand All @@ -577,7 +574,7 @@ def _decode_literal(self, data, should_index):
high_byte = data[0]
indexed_name = high_byte & 0x0F
name_len = 4
not_indexable = high_byte & 0x10
not_indexable = bool(high_byte & 0x10)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good one!

if indexed_name:
# Indexed header name.
Expand Down Expand Up @@ -616,6 +613,7 @@ def _decode_literal(self, data, should_index):

# If we have been told never to index the header field, encode that in
# the tuple we use.
header: HeaderTuple
if not_indexable:
header = NeverIndexedHeaderTuple(name, value)
else:
Expand Down
17 changes: 8 additions & 9 deletions src/hpack/huffman.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack/huffman_decoder
~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -13,11 +12,11 @@ class HuffmanEncoder:
Encodes a string according to the Huffman encoding table defined in the
HPACK specification.
"""
def __init__(self, huffman_code_list, huffman_code_list_lengths):
def __init__(self, huffman_code_list: list[int], huffman_code_list_lengths: list[int]) -> None:
self.huffman_code_list = huffman_code_list
self.huffman_code_list_lengths = huffman_code_list_lengths

def encode(self, bytes_to_encode):
def encode(self, bytes_to_encode: bytes) -> bytes:
"""
Given a string of bytes, encodes them according to the HPACK Huffman
specification.
Expand Down Expand Up @@ -48,19 +47,19 @@ def encode(self, bytes_to_encode):

# Convert the number to hex and strip off the leading '0x' and the
# trailing 'L', if present.
final_num = hex(final_num)[2:].rstrip('L')
s = hex(final_num)[2:].rstrip('L')

Kriechi marked this conversation as resolved.
Show resolved Hide resolved
# If this is odd, prepend a zero.
final_num = '0' + final_num if len(final_num) % 2 != 0 else final_num
s = '0' + s if len(s) % 2 != 0 else s

# This number should have twice as many digits as bytes. If not, we're
# missing some leading zeroes. Work out how many bytes we want and how
# many digits we have, then add the missing zero digits to the front.
total_bytes = (final_int_len + bits_to_be_padded) // 8
expected_digits = total_bytes * 2

if len(final_num) != expected_digits:
missing_digits = expected_digits - len(final_num)
final_num = ('0' * missing_digits) + final_num
if len(s) != expected_digits:
missing_digits = expected_digits - len(s)
s = ('0' * missing_digits) + s

return bytes.fromhex(final_num)
return bytes.fromhex(s)
1 change: 0 additions & 1 deletion src/hpack/huffman_constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack/huffman_constants
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
Loading
Loading