Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hints #110

Merged
merged 1 commit into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion multipart/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def write(self, data):
# Return the length of the data to indicate no error.
return len(data)

def close(self):
def close(self) -> None:
"""Close this decoder. If the underlying object has a `close()`
method, this function will call it.
"""
Expand Down
49 changes: 32 additions & 17 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import logging
import os
import shutil
Expand Down Expand Up @@ -534,14 +535,14 @@ def _get_disk_file(self):
self._actual_file_name = fname
return tmp_file

def write(self, data: bytes):
def write(self, data: bytes) -> int:
"""Write some data to the File.

:param data: a bytestring
"""
return self.on_data(data)

def on_data(self, data: bytes):
def on_data(self, data: bytes) -> int:
"""This method is a callback that will be called whenever data is
written to the File.

Expand Down Expand Up @@ -652,7 +653,7 @@ def callback(self, name: str, data: bytes | None = None, start: int | None = Non
self.logger.debug("Calling %s with no data", name)
func()

def set_callback(self, name: str, new_func):
def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None:
"""Update the function for a callback. Removes from the callbacks dict
if new_func is None.

Expand Down Expand Up @@ -1096,7 +1097,7 @@ def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, ma
# Note: the +8 is since we can have, at maximum, "\r\n--" + boundary +
# "--\r\n" at the final boundary, and the length of '\r\n--' and
# '--\r\n' is 8 bytes.
self.lookbehind = [NULL for x in range(len(boundary) + 8)]
self.lookbehind = [NULL for _ in range(len(boundary) + 8)]

def write(self, data: bytes) -> int:
"""Write some data to the parser, which will perform size verification,
Expand Down Expand Up @@ -1642,22 +1643,23 @@ def __init__(

# Depending on the Content-Type, we instantiate the correct parser.
if content_type == "application/octet-stream":
f: FileProtocol | None = None
file: FileProtocol = None # type: ignore

def on_start() -> None:
nonlocal f
f = FileClass(file_name, None, config=self.config)
nonlocal file
file = FileClass(file_name, None, config=self.config)

def on_data(data: bytes, start: int, end: int) -> None:
nonlocal f
f.write(data[start:end])
nonlocal file
file.write(data[start:end])

def _on_end() -> None:
nonlocal file
# Finalize the file itself.
f.finalize()
file.finalize()

# Call our callback.
on_file(f)
on_file(file)

# Call the on-end callback.
if self.on_end is not None:
Expand All @@ -1672,7 +1674,7 @@ def _on_end() -> None:
elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
name_buffer: list[bytes] = []

f: FieldProtocol | None = None
f: FieldProtocol = None # type: ignore

def on_field_start() -> None:
pass
Expand Down Expand Up @@ -1747,13 +1749,13 @@ def on_part_end() -> None:
else:
on_field(f)

def on_header_field(data: bytes, start: int, end: int):
def on_header_field(data: bytes, start: int, end: int) -> None:
header_name.append(data[start:end])

def on_header_value(data: bytes, start: int, end: int):
def on_header_value(data: bytes, start: int, end: int) -> None:
header_value.append(data[start:end])

def on_header_end():
def on_header_end() -> None:
headers[b"".join(header_name)] = b"".join(header_value)
del header_name[:]
del header_value[:]
Expand Down Expand Up @@ -1855,7 +1857,13 @@ def __repr__(self) -> str:
return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)


def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config={}):
def create_form_parser(
headers: dict[str, bytes],
on_field: OnFieldCallback,
on_file: OnFileCallback,
trust_x_headers: bool = False,
config={},
):
"""This function is a helper function to aid in creating a FormParser
instances. Given a dictionary-like headers object, it will determine
the correct information needed, instantiate a FormParser with the
Expand Down Expand Up @@ -1898,7 +1906,14 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config
return form_parser


def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, **kwargs):
def parse_form(
headers: dict[str, bytes],
input_stream: io.FileIO,
on_field: OnFieldCallback,
on_file: OnFileCallback,
chunk_size: int = 1048576,
**kwargs,
):
"""This function is useful if you just want to parse a request body,
without too much work. Pass it a dictionary-like object of the request's
headers, and a file-like object for the input stream, along with two
Expand Down
30 changes: 16 additions & 14 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import random
import sys
Expand Down Expand Up @@ -288,19 +290,19 @@ def setUp(self):
self.b.callbacks = {}

def test_callbacks(self):
# The stupid list-ness is to get around lack of nonlocal on py2
l = [0]
called = 0

def on_foo():
l[0] += 1
nonlocal called
called += 1

self.b.set_callback("foo", on_foo)
self.b.callback("foo")
self.assertEqual(l[0], 1)
self.assertEqual(called, 1)

self.b.set_callback("foo", None)
self.b.callback("foo")
self.assertEqual(l[0], 1)
self.assertEqual(called, 1)


class TestQuerystringParser(unittest.TestCase):
Expand All @@ -316,15 +318,15 @@ def setUp(self):
self.reset()

def reset(self):
self.f = []
self.f: list[tuple[bytes, bytes]] = []

name_buffer = []
data_buffer = []
name_buffer: list[bytes] = []
data_buffer: list[bytes] = []

def on_field_name(data, start, end):
def on_field_name(data: bytes, start: int, end: int) -> None:
name_buffer.append(data[start:end])

def on_field_data(data, start, end):
def on_field_data(data: bytes, start: int, end: int) -> None:
data_buffer.append(data[start:end])

def on_field_end():
Expand Down Expand Up @@ -705,13 +707,13 @@ def split_all(val):
class TestFormParser(unittest.TestCase):
def make(self, boundary, config={}):
self.ended = False
self.files = []
self.fields = []
self.files: list[File] = []
self.fields: list[Field] = []

def on_field(f):
def on_field(f: Field) -> None:
self.fields.append(f)

def on_file(f):
def on_file(f: File) -> None:
self.files.append(f)

def on_end():
Expand Down