Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hints on FormParser #104

Merged
merged 1 commit into from
Feb 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 84 additions & 56 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError

if TYPE_CHECKING: # pragma: no cover
from typing import Callable, TypedDict
from typing import Callable, Protocol, TypedDict

class QuerystringCallbacks(TypedDict, total=False):
on_field_start: Callable[[], None]
Expand Down Expand Up @@ -55,6 +55,30 @@ class FileConfig(TypedDict, total=False):
UPLOAD_KEEP_EXTENSIONS: bool
MAX_MEMORY_FILE_SIZE: int

class _FormProtocol(Protocol):
def write(self, data: bytes) -> int:
...

def finalize(self) -> None:
...

def close(self) -> None:
...

class FieldProtocol(_FormProtocol, Protocol):
def __init__(self, name: bytes) -> None:
...

def set_none(self) -> None:
...

class FileProtocol(_FormProtocol, Protocol):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None:
...

OnFieldCallback = Callable[[FieldProtocol], None]
OnFileCallback = Callable[[FieldProtocol], None]


# Unique missing object.
_missing = object()
Expand Down Expand Up @@ -190,15 +214,15 @@ class Field:
:param name: the name of the form field
"""

def __init__(self, name: str):
def __init__(self, name: bytes):
self._name = name
self._value: list[bytes] = []

# We cache the joined version of _value for speed.
self._cache = _missing

@classmethod
def from_value(cls, name: str, value: bytes | None) -> Field:
def from_value(cls, name: bytes, value: bytes | None) -> Field:
"""Create an instance of a :class:`Field`, and set the corresponding
value - either None or an actual value. This method will also
finalize the Field itself.
Expand Down Expand Up @@ -260,7 +284,7 @@ def set_none(self) -> None:
self._cache = None

@property
def field_name(self) -> str:
def field_name(self) -> bytes:
"""This property returns the name of the field."""
return self._name

Expand Down Expand Up @@ -1562,6 +1586,7 @@ class FormParser:
field_instance.write(data)
field_instance.finalize()
field_instance.close()
field_instance.set_none()

:param config: Configuration to use for this FormParser. The default
values are taken from the DEFAULT_CONFIG value, and then
Expand All @@ -1584,14 +1609,14 @@ class FormParser:

def __init__(
self,
content_type,
on_field,
on_file,
on_end=None,
boundary=None,
file_name=None,
FileClass=File,
FieldClass=Field,
content_type: str,
on_field: OnFieldCallback,
on_file: OnFileCallback,
on_end: Callable[[], None] | None = None,
boundary: bytes | str | None = None,
file_name: bytes | None = None,
FileClass: type[FileProtocol] = File,
FieldClass: type[FieldProtocol] = Field,
config: FormParserConfig = {},
):
self.logger = logging.getLogger(__name__)
Expand All @@ -1617,38 +1642,37 @@ def __init__(

# Depending on the Content-Type, we instantiate the correct parser.
if content_type == "application/octet-stream":
# Work around the lack of 'nonlocal' in Py2
class vars:
f = None
f: FileProtocol | None = None

def on_start() -> None:
vars.f = FileClass(file_name, None, config=self.config)
nonlocal f
f = FileClass(file_name, None, config=self.config)

def on_data(data: bytes, start: int, end: int) -> None:
vars.f.write(data[start:end])
nonlocal f
f.write(data[start:end])

def on_end() -> None:
def _on_end() -> None:
# Finalize the file itself.
vars.f.finalize()
f.finalize()

# Call our callback.
on_file(vars.f)
on_file(f)

# Call the on-end callback.
if self.on_end is not None:
self.on_end()

# Instantiate an octet-stream parser
parser = OctetStreamParser(
callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end},
callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
max_size=self.config["MAX_BODY_SIZE"],
)

elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
name_buffer: list[bytes] = []

class vars:
f = None
f: FieldProtocol | None = None

def on_field_start() -> None:
pass
Expand All @@ -1657,25 +1681,27 @@ def on_field_name(data: bytes, start: int, end: int) -> None:
name_buffer.append(data[start:end])

def on_field_data(data: bytes, start: int, end: int) -> None:
if vars.f is None:
vars.f = FieldClass(b"".join(name_buffer))
nonlocal f
if f is None:
f = FieldClass(b"".join(name_buffer))
del name_buffer[:]
vars.f.write(data[start:end])
f.write(data[start:end])

def on_field_end() -> None:
nonlocal f
# Finalize and call callback.
if vars.f is None:
if f is None:
# If we get here, it's because there was no field data.
# We create a field, set it to None, and then continue.
vars.f = FieldClass(b"".join(name_buffer))
f = FieldClass(b"".join(name_buffer))
del name_buffer[:]
vars.f.set_none()
f.set_none()

vars.f.finalize()
on_field(vars.f)
vars.f = None
f.finalize()
on_field(f)
f = None

def on_end() -> None:
def _on_end() -> None:
if self.on_end is not None:
self.on_end()

Expand All @@ -1686,7 +1712,7 @@ def on_end() -> None:
"on_field_name": on_field_name,
"on_field_data": on_field_data,
"on_field_end": on_field_end,
"on_end": on_end,
"on_end": _on_end,
},
max_size=self.config["MAX_BODY_SIZE"],
)
Expand All @@ -1700,26 +1726,26 @@ def on_end() -> None:
header_value: list[bytes] = []
headers = {}

# No 'nonlocal' on Python 2 :-(
class vars:
f = None
writer = None
is_file = False
f: FileProtocol | FieldProtocol | None = None
writer = None
is_file = False

def on_part_begin():
pass

def on_part_data(data: bytes, start: int, end: int):
bytes_processed = vars.writer.write(data[start:end])
def on_part_data(data: bytes, start: int, end: int) -> None:
nonlocal writer
bytes_processed = writer.write(data[start:end])
# TODO: check for error here.
return bytes_processed

def on_part_end() -> None:
vars.f.finalize()
if vars.is_file:
on_file(vars.f)
nonlocal f, is_file
f.finalize()
if is_file:
on_file(f)
else:
on_field(vars.f)
on_field(f)

def on_header_field(data: bytes, start: int, end: int):
header_name.append(data[start:end])
Expand All @@ -1733,8 +1759,9 @@ def on_header_end():
del header_value[:]

def on_headers_finished() -> None:
nonlocal is_file, f, writer
# Reset the 'is file' flag.
vars.is_file = False
is_file = False

# Parse the content-disposition header.
# TODO: handle mixed case
Expand All @@ -1748,24 +1775,24 @@ def on_headers_finished() -> None:

# Create the proper class.
if file_name is None:
vars.f = FieldClass(field_name)
f = FieldClass(field_name)
else:
vars.f = FileClass(file_name, field_name, config=self.config)
vars.is_file = True
f = FileClass(file_name, field_name, config=self.config)
is_file = True

# Parse the given Content-Transfer-Encoding to determine what
# we need to do with the incoming data.
# TODO: check that we properly handle 8bit / 7bit encoding.
transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit":
vars.writer = vars.f
writer = f

elif transfer_encoding == b"base64":
vars.writer = Base64Decoder(vars.f)
writer = Base64Decoder(f)

elif transfer_encoding == b"quoted-printable":
vars.writer = QuotedPrintableDecoder(vars.f)
writer = QuotedPrintableDecoder(f)

else:
self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
Expand All @@ -1774,10 +1801,11 @@ def on_headers_finished() -> None:
else:
# If we aren't erroring, then we just treat this as an
# unencoded Content-Transfer-Encoding.
vars.writer = vars.f
writer = f

def on_end() -> None:
vars.writer.finalize()
def _on_end() -> None:
nonlocal writer
writer.finalize()
if self.on_end is not None:
self.on_end()

Expand All @@ -1792,7 +1820,7 @@ def on_end() -> None:
"on_header_value": on_header_value,
"on_header_end": on_header_end,
"on_headers_finished": on_headers_finished,
"on_end": on_end,
"on_end": _on_end,
},
max_size=self.config["MAX_BODY_SIZE"],
)
Expand Down