Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypedDict callbacks #98

Merged
merged 1 commit into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ lib64
pip-log.txt

# Unit test / coverage reports
.coverage.*
.coverage*
.tox
nosetests.xml

Expand Down
75 changes: 52 additions & 23 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,38 @@
from enum import IntEnum
from io import BytesIO
from numbers import Number
from typing import Dict, Tuple, Union
from typing import TYPE_CHECKING

from .decoders import Base64Decoder, QuotedPrintableDecoder
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError

if TYPE_CHECKING: # pragma: no cover
from typing import Callable, TypedDict

class QuerystringCallbacks(TypedDict, total=False):
on_field_start: Callable[[], None]
on_field_name: Callable[[bytes, int, int], None]
on_field_data: Callable[[bytes, int, int], None]
on_field_end: Callable[[], None]
on_end: Callable[[], None]

class OctetStreamCallbacks(TypedDict, total=False):
on_start: Callable[[], None]
on_data: Callable[[bytes, int, int], None]
on_end: Callable[[], None]

class MultipartCallbacks(TypedDict, total=False):
on_part_begin: Callable[[], None]
on_part_data: Callable[[bytes, int, int], None]
on_part_end: Callable[[], None]
on_headers_begin: Callable[[], None]
on_header_field: Callable[[bytes, int, int], None]
on_header_value: Callable[[bytes, int, int], None]
on_header_end: Callable[[], None]
on_headers_finished: Callable[[], None]
on_end: Callable[[], None]


# Unique missing object.
_missing = object()

Expand Down Expand Up @@ -86,7 +113,7 @@ def join_bytes(b):
return bytes(list(b))


def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]:
def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
"""
Parses a Content-Type header into a value in the following format:
(content_type, {parameters})
Expand Down Expand Up @@ -148,15 +175,15 @@ class Field:
:param name: the name of the form field
"""

def __init__(self, name):
def __init__(self, name: str):
self._name = name
self._value = []
self._value: list[bytes] = []

# We cache the joined version of _value for speed.
self._cache = _missing

@classmethod
def from_value(klass, name, value):
def from_value(cls, name: str, value: bytes | None) -> Field:
"""Create an instance of a :class:`Field`, and set the corresponding
value - either None or an actual value. This method will also
finalize the Field itself.
Expand All @@ -166,22 +193,22 @@ def from_value(klass, name, value):
None
"""

f = klass(name)
f = cls(name)
if value is None:
f.set_none()
else:
f.write(value)
f.finalize()
return f

def write(self, data):
def write(self, data: bytes) -> int:
"""Write some data into the form field.

:param data: a bytestring
"""
return self.on_data(data)

def on_data(self, data):
def on_data(self, data: bytes) -> int:
"""This method is a callback that will be called whenever data is
written to the Field.

Expand All @@ -191,24 +218,24 @@ def on_data(self, data):
self._cache = _missing
return len(data)

def on_end(self):
def on_end(self) -> None:
"""This method is called whenever the Field is finalized."""
if self._cache is _missing:
self._cache = b"".join(self._value)

def finalize(self):
def finalize(self) -> None:
"""Finalize the form field."""
self.on_end()

def close(self):
def close(self) -> None:
"""Close the Field object. This will free any underlying cache."""
# Free our value array.
if self._cache is _missing:
self._cache = b"".join(self._value)

del self._value

def set_none(self):
def set_none(self) -> None:
"""Some fields in a querystring can possibly have a value of None - for
example, the string "foo&bar=&baz=asdf" will have a field with the
name "foo" and value None, one with name "bar" and value "", and one
Expand All @@ -218,7 +245,7 @@ def set_none(self):
self._cache = None

@property
def field_name(self):
def field_name(self) -> str:
"""This property returns the name of the field."""
return self._name

Expand All @@ -230,13 +257,13 @@ def value(self):

return self._cache

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, Field):
return self.field_name == other.field_name and self.value == other.value
else:
return NotImplemented

def __repr__(self):
def __repr__(self) -> str:
if len(self.value) > 97:
# We get the repr, and then insert three dots before the final
# quote.
Expand Down Expand Up @@ -553,7 +580,7 @@ class BaseParser:
def __init__(self):
self.logger = logging.getLogger(__name__)

def callback(self, name, data=None, start=None, end=None):
def callback(self, name: str, data=None, start=None, end=None):
"""This function calls a provided callback with some data. If the
callback is not set, will do nothing.

Expand Down Expand Up @@ -584,7 +611,7 @@ def callback(self, name, data=None, start=None, end=None):
self.logger.debug("Calling %s with no data", name)
func()

def set_callback(self, name, new_func):
def set_callback(self, name: str, new_func):
"""Update the function for a callback. Removes from the callbacks dict
if new_func is None.

Expand Down Expand Up @@ -637,7 +664,7 @@ class OctetStreamParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, callbacks={}, max_size=float("inf")):
def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
super().__init__()
self.callbacks = callbacks
self._started = False
Expand All @@ -647,7 +674,7 @@ def __init__(self, callbacks={}, max_size=float("inf")):
self.max_size = max_size
self._current_size = 0

def write(self, data):
def write(self, data: bytes):
"""Write some data to the parser, which will perform size verification,
and then pass the data to the underlying callback.

Expand Down Expand Up @@ -732,7 +759,9 @@ class QuerystringParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
state: QuerystringState

def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")):
super().__init__()
self.state = QuerystringState.BEFORE_FIELD
self._found_sep = False
Expand All @@ -748,7 +777,7 @@ def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
# Should parsing be strict?
self.strict_parsing = strict_parsing

def write(self, data):
def write(self, data: bytes):
"""Write some data to the parser, which will perform size verification,
parse into either a field name or value, and then pass the
corresponding data to the underlying callback. If an error is
Expand Down Expand Up @@ -780,7 +809,7 @@ def write(self, data):

return l

def _internal_write(self, data, length):
def _internal_write(self, data: bytes, length: int):
state = self.state
strict_parsing = self.strict_parsing
found_sep = self._found_sep
Expand Down Expand Up @@ -989,7 +1018,7 @@ class MultipartParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, boundary, callbacks={}, max_size=float("inf")):
def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
# Initialize parser state.
super().__init__()
self.state = MultipartState.START
Expand Down
16 changes: 7 additions & 9 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,9 @@ def on_field_end():
del name_buffer[:]
del data_buffer[:]

callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}

self.p = QuerystringParser(callbacks)
self.p = QuerystringParser(
callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
)

def test_simple_querystring(self):
self.p.write(b"foo=bar")
Expand Down Expand Up @@ -464,18 +464,16 @@ def setUp(self):
self.started = 0
self.finished = 0

def on_start():
def on_start() -> None:
self.started += 1

def on_data(data, start, end):
def on_data(data: bytes, start: int, end: int) -> None:
self.d.append(data[start:end])

def on_end():
def on_end() -> None:
self.finished += 1

callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end}

self.p = OctetStreamParser(callbacks)
self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end})

def assert_data(self, data, finalize=True):
self.assertEqual(b"".join(self.d), data)
Expand Down