From ac4315056cabac57d282bca6921ed3c0a11716f2 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:42:49 +0100 Subject: [PATCH 1/3] BaseParser must be provided a callbacks mapping --- multipart/multipart.py | 20 ++++++++------------ tests/test_multipart.py | 4 ++-- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index ea8ccca..5dff296 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -16,7 +16,7 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Callable, Protocol, TypedDict + from typing import Callable, Protocol, TypedDict, Mapping class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] @@ -608,8 +608,9 @@ class BaseParser: performance. """ - def __init__(self) -> None: + def __init__(self, callbacks: Mapping) -> None: self.logger = logging.getLogger(__name__) + self.callbacks = callbacks def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None): """This function calls a provided callback with some data. If the @@ -653,9 +654,9 @@ def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None: exist). """ if new_func is None: - self.callbacks.pop("on_" + name, None) + self.callbacks.pop("on_" + name, None) # TODO: MutableMapping breaks compatibility with TypedDict else: - self.callbacks["on_" + name] = new_func + self.callbacks["on_" + name] = new_func # TODO: MutableMapping breaks compatibility with TypedDict def close(self): pass # pragma: no cover @@ -696,8 +697,7 @@ class OctetStreamParser(BaseParser): """ def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")): - super().__init__() - self.callbacks = callbacks + super().__init__(callbacks) self._started = False if not isinstance(max_size, Number) or max_size < 1: @@ -795,12 +795,10 @@ class QuerystringParser(BaseParser): def __init__( self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf") ) -> None: - super().__init__() + super().__init__(callbacks) self.state = QuerystringState.BEFORE_FIELD self._found_sep = False - self.callbacks = callbacks - # Max-size stuff if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) @@ -1055,12 +1053,10 @@ def __init__( self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf") ) -> None: # Initialize parser state. - super().__init__() + super().__init__(callbacks) self.state = MultipartState.START self.index = self.flags = 0 - self.callbacks = callbacks - if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 93fd38d..086d32e 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -290,8 +290,8 @@ def test_handles_rfc_2231(self): class TestBaseParser(unittest.TestCase): def setUp(self): - self.b = BaseParser() - self.b.callbacks = {} + callbacks = {} + self.b = BaseParser(callbacks) def test_callbacks(self): called = 0 From 4a6f0ea1b16acc3f6b8a5cb355c464cf3b9cc802 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:57:02 +0100 Subject: [PATCH 2/3] Applied formatting --- multipart/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index 5dff296..be24ab2 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -16,7 +16,7 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Callable, Protocol, TypedDict, Mapping + from typing import Callable, Mapping, Protocol, TypedDict class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] From e0ad7509027b7c39066cc19e706fa7ba7a05be77 Mon Sep 17 00:00:00 2001 From: eltbus <33374178+eltbus@users.noreply.github.com> Date: Sat, 17 Feb 2024 08:17:57 +0100 Subject: [PATCH 3/3] Replace Mapping with a Base TypedDict --- multipart/multipart.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/multipart/multipart.py b/multipart/multipart.py index be24ab2..bd836bd 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -16,21 +16,22 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Callable, Mapping, Protocol, TypedDict + from typing import Callable, Protocol, TypedDict - class QuerystringCallbacks(TypedDict, total=False): + class BaseCallbacks(TypedDict, total=False): + on_end: Callable[[], None] + + class QuerystringCallbacks(BaseCallbacks, total=False): on_field_start: Callable[[], None] on_field_name: Callable[[bytes, int, int], None] on_field_data: Callable[[bytes, int, int], None] on_field_end: Callable[[], None] - on_end: Callable[[], None] - class OctetStreamCallbacks(TypedDict, total=False): + class OctetStreamCallbacks(BaseCallbacks, total=False): on_start: Callable[[], None] on_data: Callable[[bytes, int, int], None] - on_end: Callable[[], None] - class MultipartCallbacks(TypedDict, total=False): + class MultipartCallbacks(BaseCallbacks, total=False): on_part_begin: Callable[[], None] on_part_data: Callable[[bytes, int, int], None] on_part_end: Callable[[], None] @@ -39,7 +40,6 @@ class MultipartCallbacks(TypedDict, total=False): on_header_value: Callable[[bytes, int, int], None] on_header_end: Callable[[], None] on_headers_finished: Callable[[], None] - on_end: Callable[[], None] class FormParserConfig(TypedDict): UPLOAD_DIR: str | None @@ -608,7 +608,7 @@ class BaseParser: performance. """ - def __init__(self, callbacks: Mapping) -> None: + def __init__(self, callbacks: BaseCallbacks) -> None: self.logger = logging.getLogger(__name__) self.callbacks = callbacks @@ -654,9 +654,9 @@ def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None: exist). """ if new_func is None: - self.callbacks.pop("on_" + name, None) # TODO: MutableMapping breaks compatibility with TypedDict + self.callbacks.pop("on_" + name, None) else: - self.callbacks["on_" + name] = new_func # TODO: MutableMapping breaks compatibility with TypedDict + self.callbacks["on_" + name] = new_func def close(self): pass # pragma: no cover