diff --git a/multipart/decoders.py b/multipart/decoders.py index 417650c..e401fa0 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -73,7 +73,7 @@ def write(self, data): # Return the length of the data to indicate no error. return len(data) - def close(self): + def close(self) -> None: """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ diff --git a/multipart/multipart.py b/multipart/multipart.py index 21d9ac1..0c7c447 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import logging import os import shutil @@ -534,14 +535,14 @@ def _get_disk_file(self): self._actual_file_name = fname return tmp_file - def write(self, data: bytes): + def write(self, data: bytes) -> int: """Write some data to the File. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data: bytes): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the File. @@ -652,7 +653,7 @@ def callback(self, name: str, data: bytes | None = None, start: int | None = Non self.logger.debug("Calling %s with no data", name) func() - def set_callback(self, name: str, new_func): + def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None: """Update the function for a callback. Removes from the callbacks dict if new_func is None. @@ -1096,7 +1097,7 @@ def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, ma # Note: the +8 is since we can have, at maximum, "\r\n--" + boundary + # "--\r\n" at the final boundary, and the length of '\r\n--' and # '--\r\n' is 8 bytes. - self.lookbehind = [NULL for x in range(len(boundary) + 8)] + self.lookbehind = [NULL for _ in range(len(boundary) + 8)] def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, @@ -1642,22 +1643,23 @@ def __init__( # Depending on the Content-Type, we instantiate the correct parser. if content_type == "application/octet-stream": - f: FileProtocol | None = None + file: FileProtocol = None # type: ignore def on_start() -> None: - nonlocal f - f = FileClass(file_name, None, config=self.config) + nonlocal file + file = FileClass(file_name, None, config=self.config) def on_data(data: bytes, start: int, end: int) -> None: - nonlocal f - f.write(data[start:end]) + nonlocal file + file.write(data[start:end]) def _on_end() -> None: + nonlocal file # Finalize the file itself. - f.finalize() + file.finalize() # Call our callback. - on_file(f) + on_file(file) # Call the on-end callback. if self.on_end is not None: @@ -1672,7 +1674,7 @@ def _on_end() -> None: elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": name_buffer: list[bytes] = [] - f: FieldProtocol | None = None + f: FieldProtocol = None # type: ignore def on_field_start() -> None: pass @@ -1747,13 +1749,13 @@ def on_part_end() -> None: else: on_field(f) - def on_header_field(data: bytes, start: int, end: int): + def on_header_field(data: bytes, start: int, end: int) -> None: header_name.append(data[start:end]) - def on_header_value(data: bytes, start: int, end: int): + def on_header_value(data: bytes, start: int, end: int) -> None: header_value.append(data[start:end]) - def on_header_end(): + def on_header_end() -> None: headers[b"".join(header_name)] = b"".join(header_value) del header_name[:] del header_value[:] @@ -1855,7 +1857,13 @@ def __repr__(self) -> str: return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser) -def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config={}): +def create_form_parser( + headers: dict[str, bytes], + on_field: OnFieldCallback, + on_file: OnFileCallback, + trust_x_headers: bool = False, + config={}, +): """This function is a helper function to aid in creating a FormParser instances. Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1898,7 +1906,14 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config return form_parser -def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, **kwargs): +def parse_form( + headers: dict[str, bytes], + input_stream: io.FileIO, + on_field: OnFieldCallback, + on_file: OnFileCallback, + chunk_size: int = 1048576, + **kwargs, +): """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's headers, and a file-like object for the input stream, along with two diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9dc10bb..79968e0 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import random import sys @@ -288,19 +290,19 @@ def setUp(self): self.b.callbacks = {} def test_callbacks(self): - # The stupid list-ness is to get around lack of nonlocal on py2 - l = [0] + called = 0 def on_foo(): - l[0] += 1 + nonlocal called + called += 1 self.b.set_callback("foo", on_foo) self.b.callback("foo") - self.assertEqual(l[0], 1) + self.assertEqual(called, 1) self.b.set_callback("foo", None) self.b.callback("foo") - self.assertEqual(l[0], 1) + self.assertEqual(called, 1) class TestQuerystringParser(unittest.TestCase): @@ -316,15 +318,15 @@ def setUp(self): self.reset() def reset(self): - self.f = [] + self.f: list[tuple[bytes, bytes]] = [] - name_buffer = [] - data_buffer = [] + name_buffer: list[bytes] = [] + data_buffer: list[bytes] = [] - def on_field_name(data, start, end): + def on_field_name(data: bytes, start: int, end: int) -> None: name_buffer.append(data[start:end]) - def on_field_data(data, start, end): + def on_field_data(data: bytes, start: int, end: int) -> None: data_buffer.append(data[start:end]) def on_field_end(): @@ -705,13 +707,13 @@ def split_all(val): class TestFormParser(unittest.TestCase): def make(self, boundary, config={}): self.ended = False - self.files = [] - self.fields = [] + self.files: list[File] = [] + self.fields: list[Field] = [] - def on_field(f): + def on_field(f: Field) -> None: self.fields.append(f) - def on_file(f): + def on_file(f: File) -> None: self.files.append(f) def on_end():