From a169d93db50853105cf912653de91f5d4a790db2 Mon Sep 17 00:00:00 2001 From: John Stark Date: Sun, 29 Sep 2024 08:45:57 +0100 Subject: [PATCH] Add mypy strict typing (#140) * No errors with mypy --strict * Apply ruff formatting * Add py.typed file * Make it more modern * Add strict mode to mypy * Use --with instead of --from --------- Co-authored-by: Marcelo Trylesinski --- .github/workflows/main.yml | 2 +- .gitignore | 1 + multipart/decoders.py | 23 ++- multipart/multipart.py | 172 +++++++++++------- multipart/py.typed | 0 pyproject.toml | 5 + scripts/README.md | 8 + scripts/check | 9 + scripts/setup | 3 + tests/compat.py | 26 ++- tests/test_multipart.py | 351 +++++++++++++++++++------------------ uv.lock | 63 ++++++- 12 files changed, 412 insertions(+), 251 deletions(-) create mode 100644 multipart/py.typed create mode 100644 scripts/README.md create mode 100755 scripts/check create mode 100755 scripts/setup diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9b5ed27..1881a56 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: run: uv sync --python ${{ matrix.python-version }} --frozen - name: Run linters - run: scripts/lint + run: scripts/check - name: Run tests run: scripts/test diff --git a/.gitignore b/.gitignore index f52a6b1..8c8a694 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +.ruff_cache/ cover/ # Translations diff --git a/multipart/decoders.py b/multipart/decoders.py index 135c56c..07bf742 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -1,9 +1,22 @@ import base64 import binascii -from io import BufferedWriter +from typing import TYPE_CHECKING from .exceptions import DecodeError +if TYPE_CHECKING: # pragma: no cover + from typing import Protocol, TypeVar + + _T_contra = TypeVar("_T_contra", contravariant=True) + + class SupportsWrite(Protocol[_T_contra]): + def write(self, __b: _T_contra) -> object: ... + + # No way to specify optional methods. See + # https://github.com/python/typing/issues/601 + # close() [Optional] + # finalize() [Optional] + class Base64Decoder: """This object provides an interface to decode a stream of Base64 data. It @@ -34,7 +47,7 @@ class Base64Decoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying: BufferedWriter): + def __init__(self, underlying: "SupportsWrite[bytes]") -> None: self.cache = bytearray() self.underlying = underlying @@ -67,9 +80,9 @@ def write(self, data: bytes) -> int: # Get the remaining bytes and save in our cache. remaining_len = len(data) % 4 if remaining_len > 0: - self.cache = data[-remaining_len:] + self.cache[:] = data[-remaining_len:] else: - self.cache = b"" + self.cache[:] = b"" # Return the length of the data to indicate no error. return len(data) @@ -112,7 +125,7 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying: BufferedWriter) -> None: + def __init__(self, underlying: "SupportsWrite[bytes]") -> None: self.cache = b"" self.underlying = underlying diff --git a/multipart/multipart.py b/multipart/multipart.py index 18d0f1d..137d6e7 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io import logging import os import shutil @@ -8,15 +7,20 @@ import tempfile from email.message import Message from enum import IntEnum -from io import BytesIO +from io import BufferedRandom, BytesIO from numbers import Number -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, cast from .decoders import Base64Decoder, QuotedPrintableDecoder from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Callable, Protocol, TypedDict + from typing import Any, Callable, Literal, Protocol, TypedDict + + from typing_extensions import TypeAlias + + class SupportsRead(Protocol): + def read(self, __n: int) -> bytes: ... class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] @@ -64,7 +68,7 @@ def finalize(self) -> None: ... def close(self) -> None: ... class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes) -> None: ... + def __init__(self, name: bytes | None) -> None: ... def set_none(self) -> None: ... @@ -74,6 +78,23 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Fi OnFieldCallback = Callable[[FieldProtocol], None] OnFileCallback = Callable[[FileProtocol], None] + CallbackName: TypeAlias = Literal[ + "start", + "data", + "end", + "field_start", + "field_name", + "field_data", + "field_end", + "part_begin", + "part_data", + "part_end", + "header_begin", + "header_field", + "header_value", + "header_end", + "headers_finished", + ] # Unique missing object. _missing = object() @@ -142,7 +163,7 @@ class MultipartState(IntEnum): # fmt: on -def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]: +def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]: """Parses a Content-Type header into a value in the following format: (content_type, {parameters}).""" # Uses email.message.Message to parse the header as described in PEP 594. # Ref: https://peps.python.org/pep-0594/#cgi @@ -202,7 +223,7 @@ class Field: name: The name of the form field. """ - def __init__(self, name: bytes) -> None: + def __init__(self, name: bytes | None) -> None: self._name = name self._value: list[bytes] = [] @@ -283,7 +304,7 @@ def set_none(self) -> None: self._cache = None @property - def field_name(self) -> bytes: + def field_name(self) -> bytes | None: """This property returns the name of the field.""" return self._name @@ -293,6 +314,7 @@ def value(self) -> bytes | None: if self._cache is _missing: self._cache = b"".join(self._value) + assert isinstance(self._cache, bytes) or self._cache is None return self._cache def __eq__(self, other: object) -> bool: @@ -341,7 +363,7 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None = None, con self._config = config self._in_memory = True self._bytes_written = 0 - self._fileobj = BytesIO() + self._fileobj: BytesIO | BufferedRandom = BytesIO() # Save the provided field/file name. self._field_name = field_name @@ -349,7 +371,7 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None = None, con # Our actual file name is None by default, since, depending on our # config, we may not actually use the provided name. - self._actual_file_name = None + self._actual_file_name: bytes | None = None # Split the extension from the filename. if file_name is not None: @@ -370,14 +392,14 @@ def file_name(self) -> bytes | None: return self._file_name @property - def actual_file_name(self): + def actual_file_name(self) -> bytes | None: """The file name that this file is saved as. Will be None if it's not currently saved on disk. """ return self._actual_file_name @property - def file_object(self): + def file_object(self) -> BytesIO | BufferedRandom: """The file object that we're currently writing to. Note that this will either be an instance of a :class:`io.BytesIO`, or a regular file object. @@ -432,7 +454,7 @@ def flush_to_disk(self) -> None: # Close the old file object. old_fileobj.close() - def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[bytes]: # type: ignore[reportPrivateUsage] + def _get_disk_file(self) -> BufferedRandom: """This function is responsible for getting a file object on-disk for us.""" self.logger.info("Opening a file on disk") @@ -440,6 +462,7 @@ def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[b keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False) keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False) delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True) + tmp_file: None | BufferedRandom = None # If we have a directory and are to keep the filename... if file_dir is not None and keep_filename: @@ -449,7 +472,7 @@ def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[b # TODO: what happens if we don't have a filename? fname = self._file_base + self._ext if keep_extensions else self._file_base - path = os.path.join(file_dir, fname) + path = os.path.join(file_dir, fname) # type: ignore[arg-type] try: self.logger.info("Opening file: %r", path) tmp_file = open(path, "w+b") @@ -476,16 +499,17 @@ def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[b "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir} ) try: - tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir) + tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir)) except OSError: self.logger.exception("Error creating named temporary file") raise FileError("Error creating named temporary file") - fname = tmp_file.name - + assert tmp_file is not None # Encode filename as bytes. - if isinstance(fname, str): - fname = fname.encode(sys.getfilesystemencoding()) + if isinstance(tmp_file.name, str): + fname = tmp_file.name.encode(sys.getfilesystemencoding()) + else: + fname = cast(bytes, tmp_file.name) # pragma: no cover self._actual_file_name = fname return tmp_file @@ -571,8 +595,11 @@ class BaseParser: def __init__(self) -> None: self.logger = logging.getLogger(__name__) + self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {} - def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None): + def callback( + self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None + ) -> None: """This function calls a provided callback with some data. If the callback is not set, will do nothing. @@ -583,24 +610,24 @@ def callback(self, name: str, data: bytes | None = None, start: int | None = Non end: An integer that is passed to the data callback. start: An integer that is passed to the data callback. """ - name = "on_" + name - func = self.callbacks.get(name) + on_name = "on_" + name + func = self.callbacks.get(on_name) if func is None: return - + func = cast("Callable[..., Any]", func) # Depending on whether we're given a buffer... if data is not None: # Don't do anything if we have start == end. if start is not None and start == end: return - self.logger.debug("Calling %s with data[%d:%d]", name, start, end) + self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end) func(data, start, end) else: - self.logger.debug("Calling %s with no data", name) + self.logger.debug("Calling %s with no data", on_name) func() - def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None: + def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None: """Update the function for a callback. Removes from the callbacks dict if new_func is None. @@ -611,17 +638,17 @@ def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None: exist). """ if new_func is None: - self.callbacks.pop("on_" + name, None) + self.callbacks.pop("on_" + name, None) # type: ignore[misc] else: - self.callbacks["on_" + name] = new_func + self.callbacks["on_" + name] = new_func # type: ignore[literal-required] - def close(self): + def close(self) -> None: pass # pragma: no cover - def finalize(self): + def finalize(self) -> None: pass # pragma: no cover - def __repr__(self): + def __repr__(self) -> str: return "%s()" % self.__class__.__name__ @@ -647,7 +674,7 @@ def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size = max_size + self.max_size: int | float = max_size self._current_size = 0 def write(self, data: bytes) -> int: @@ -729,7 +756,7 @@ def __init__( # Max-size stuff if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size = max_size + self.max_size: int | float = max_size self._current_size = 0 # Should parsing be strict? @@ -1019,7 +1046,7 @@ def _internal_write(self, data: bytes, length: int) -> int: i = 0 # Set a mark. - def set_mark(name: str): + def set_mark(name: str) -> None: self.marks[name] = i # Remove a mark. @@ -1031,7 +1058,7 @@ def delete_mark(name: str, reset: bool = False) -> None: # end of the buffer, and reset the mark, instead of deleting it. This # is used at the end of the function to call our callbacks with any # remaining data in this chunk. - def data_callback(name: str, end_i: int, remaining: bool = False) -> None: + def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None: marked_index = self.marks.get(name) if marked_index is None: return @@ -1471,8 +1498,8 @@ class if you wish to customize behaviour. The class will be instantiated as Fie def __init__( self, content_type: str, - on_field: OnFieldCallback, - on_file: OnFileCallback, + on_field: OnFieldCallback | None, + on_file: OnFileCallback | None, on_end: Callable[[], None] | None = None, boundary: bytes | str | None = None, file_name: bytes | None = None, @@ -1498,8 +1525,10 @@ def __init__( self.FieldClass = Field # Set configuration options. - self.config = self.DEFAULT_CONFIG.copy() - self.config.update(config) + self.config: FormParserConfig = self.DEFAULT_CONFIG.copy() + self.config.update(config) # type: ignore[typeddict-item] + + parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None # Depending on the Content-Type, we instantiate the correct parser. if content_type == "application/octet-stream": @@ -1507,7 +1536,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, None, config=self.config) + file = FileClass(file_name, None, config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file @@ -1519,7 +1548,8 @@ def _on_end() -> None: file.finalize() # Call our callback. - on_file(file) + if on_file: + on_file(file) # Call the on-end callback. if self.on_end is not None: @@ -1534,7 +1564,7 @@ def _on_end() -> None: elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": name_buffer: list[bytes] = [] - f: FieldProtocol = None # type: ignore + f: FieldProtocol | None = None def on_field_start() -> None: pass @@ -1560,7 +1590,8 @@ def on_field_end() -> None: f.set_none() f.finalize() - on_field(f) + if on_field: + on_field(f) f = None def _on_end() -> None: @@ -1586,30 +1617,33 @@ def _on_end() -> None: header_name: list[bytes] = [] header_value: list[bytes] = [] - headers = {} + headers: dict[bytes, bytes] = {} - f: FileProtocol | FieldProtocol | None = None + f_multi: FileProtocol | FieldProtocol | None = None writer = None is_file = False - def on_part_begin(): + def on_part_begin() -> None: # Reset headers in case this isn't the first part. nonlocal headers headers = {} def on_part_data(data: bytes, start: int, end: int) -> None: nonlocal writer - bytes_processed = writer.write(data[start:end]) + assert writer is not None + writer.write(data[start:end]) # TODO: check for error here. - return bytes_processed def on_part_end() -> None: - nonlocal f, is_file - f.finalize() + nonlocal f_multi, is_file + assert f_multi is not None + f_multi.finalize() if is_file: - on_file(f) + if on_file: + on_file(f_multi) else: - on_field(f) + if on_field: + on_field(cast("FieldProtocol", f_multi)) def on_header_field(data: bytes, start: int, end: int) -> None: header_name.append(data[start:end]) @@ -1623,7 +1657,7 @@ def on_header_end() -> None: del header_value[:] def on_headers_finished() -> None: - nonlocal is_file, f, writer + nonlocal is_file, f_multi, writer # Reset the 'is file' flag. is_file = False @@ -1639,9 +1673,9 @@ def on_headers_finished() -> None: # Create the proper class. if file_name is None: - f = FieldClass(field_name) + f_multi = FieldClass(field_name) else: - f = FileClass(file_name, field_name, config=self.config) + f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config)) is_file = True # Parse the given Content-Transfer-Encoding to determine what @@ -1650,25 +1684,26 @@ def on_headers_finished() -> None: transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") if transfer_encoding in (b"binary", b"8bit", b"7bit"): - writer = f + writer = f_multi elif transfer_encoding == b"base64": - writer = Base64Decoder(f) + writer = Base64Decoder(f_multi) elif transfer_encoding == b"quoted-printable": - writer = QuotedPrintableDecoder(f) + writer = QuotedPrintableDecoder(f_multi) else: self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding) if self.config["UPLOAD_ERROR_ON_BAD_CTE"]: - raise FormParserError('Unknown Content-Transfer-Encoding "{}"'.format(transfer_encoding)) + raise FormParserError('Unknown Content-Transfer-Encoding "{!r}"'.format(transfer_encoding)) else: # If we aren't erroring, then we just treat this as an # unencoded Content-Transfer-Encoding. - writer = f + writer = f_multi def _on_end() -> None: nonlocal writer + assert writer is not None writer.finalize() if self.on_end is not None: self.on_end() @@ -1707,6 +1742,7 @@ def write(self, data: bytes) -> int: """ self.bytes_received += len(data) # TODO: check the parser's return value for errors? + assert self.parser is not None return self.parser.write(data) def finalize(self) -> None: @@ -1725,8 +1761,8 @@ def __repr__(self) -> str: def create_form_parser( headers: dict[str, bytes], - on_field: OnFieldCallback, - on_file: OnFileCallback, + on_field: OnFieldCallback | None, + on_file: OnFileCallback | None, trust_x_headers: bool = False, config: dict[Any, Any] = {}, ) -> FormParser: @@ -1744,7 +1780,7 @@ def create_form_parser( name from X-File-Name. config: Configuration variables to pass to the FormParser. """ - content_type = headers.get("Content-Type") + content_type: str | bytes | None = headers.get("Content-Type") if content_type is None: logging.getLogger(__name__).warning("No Content-Type header given") raise ValueError("No Content-Type header given!") @@ -1769,9 +1805,9 @@ def create_form_parser( def parse_form( headers: dict[str, bytes], - input_stream: io.FileIO, - on_field: OnFieldCallback, - on_file: OnFileCallback, + input_stream: SupportsRead, + on_field: OnFieldCallback | None, + on_file: OnFileCallback | None, chunk_size: int = 1048576, ) -> None: """This function is useful if you just want to parse a request body, @@ -1792,7 +1828,7 @@ def parse_form( # Read chunks of 1MiB and write to the parser, but never read more than # the given Content-Length, if any. - content_length = headers.get("Content-Length") + content_length: int | float | bytes | None = headers.get("Content-Length") if content_length is not None: content_length = int(content_length) else: @@ -1801,7 +1837,7 @@ def parse_form( while True: # Read only up to the Content-Length given. - max_readable = min(content_length - bytes_read, chunk_size) + max_readable = int(min(content_length - bytes_read, chunk_size)) buff = input_stream.read(max_readable) # Write to the parser and update our length. diff --git a/multipart/py.typed b/multipart/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index f672c70..fb03f83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,8 @@ dev-dependencies = [ "invoke==2.2.0", "pytest-timeout==2.3.1", "ruff==0.3.4", + "mypy", + "types-PyYAML", "atheris==2.3.0; python_version != '3.12'", # Documentation "mkdocs", @@ -68,6 +70,9 @@ packages = ["multipart"] [tool.hatch.build.targets.sdist] include = ["/multipart", "/tests", "CHANGELOG.md", "LICENSE.txt"] +[tool.mypy] +strict = true + [tool.ruff] line-length = 120 diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..1742ebd --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,8 @@ +# Development Scripts + +* `scripts/setup` - Install dependencies. +* `scripts/test` - Run the test suite. +* `scripts/lint` - Run the code format. +* `scripts/check` - Run the lint in check mode, and the type checker. + +Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all). diff --git a/scripts/check b/scripts/check new file mode 100755 index 0000000..0b6a294 --- /dev/null +++ b/scripts/check @@ -0,0 +1,9 @@ +#!/bin/sh -e + +set -x + +SOURCE_FILES="multipart tests" + +uvx ruff format --check --diff $SOURCE_FILES +uvx ruff check $SOURCE_FILES +uvx --with types-PyYAML mypy $SOURCE_FILES diff --git a/scripts/setup b/scripts/setup new file mode 100755 index 0000000..33797fc --- /dev/null +++ b/scripts/setup @@ -0,0 +1,3 @@ +#!/bin/sh -ex + +uv sync --frozen diff --git a/tests/compat.py b/tests/compat.py index 845a926..2253107 100644 --- a/tests/compat.py +++ b/tests/compat.py @@ -1,18 +1,24 @@ +from __future__ import annotations + import functools import os import re import sys import types +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable -def ensure_in_path(path): +def ensure_in_path(path: str) -> None: """ Ensure that a given path is in the sys.path array """ if not os.path.isdir(path): raise RuntimeError("Tried to add nonexisting path") - def _samefile(x, y): + def _samefile(x: str, y: str) -> bool: try: return os.path.samefile(x, y) except OSError: @@ -34,7 +40,7 @@ def _samefile(x, y): # We don't use the pytest parametrizing function, since it seems to break # with unittest.TestCase subclasses. -def parametrize(field_names, field_values): +def parametrize(field_names: tuple[str] | list[str] | str, field_values: list[Any] | Any) -> Callable[..., Any]: # If we're not given a list of field names, we make it. if not isinstance(field_names, (tuple, list)): field_names = (field_names,) @@ -42,7 +48,7 @@ def parametrize(field_names, field_values): # Create a decorator that saves this list of field names and values on the # function for later parametrizing. - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: func.__dict__["param_names"] = field_names func.__dict__["param_values"] = field_values return func @@ -54,7 +60,7 @@ def decorator(func): class ParametrizingMetaclass(type): IDENTIFIER_RE = re.compile("[^A-Za-z0-9]") - def __new__(klass, name, bases, attrs): + def __new__(klass, name: str, bases: tuple[type, ...], attrs: types.MappingProxyType[str, Any]) -> type: new_attrs = attrs.copy() for attr_name, attr in attrs.items(): # We only care about functions @@ -67,7 +73,7 @@ def __new__(klass, name, bases, attrs): continue # Create multiple copies of the function. - for i, values in enumerate(param_values): + for _, values in enumerate(param_values): assert len(param_names) == len(values) # Get a repr of the values, and fix it to be a valid identifier @@ -78,12 +84,14 @@ def __new__(klass, name, bases, attrs): new_name = attr.__name__ + "__" + human # Create a replacement function. - def create_new_func(func, names, values): + def create_new_func( + func: types.FunctionType, names: list[str], values: list[Any] + ) -> Callable[..., Any]: # Create a kwargs dictionary. kwargs = dict(zip(names, values)) @functools.wraps(func) - def new_func(self): + def new_func(self: types.FunctionType) -> Any: return func(self, **kwargs) # Manually set the name and return the new function. @@ -104,5 +112,5 @@ def new_func(self): # This is a class decorator that actually applies the above metaclass. -def parametrize_class(klass): +def parametrize_class(klass: type) -> ParametrizingMetaclass: return ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 2e22812..f55e228 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -6,13 +6,13 @@ import tempfile import unittest from io import BytesIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from unittest.mock import Mock import yaml from multipart.decoders import Base64Decoder, QuotedPrintableDecoder -from multipart.exceptions import DecodeError, FileError, FormParserError, MultipartParseError +from multipart.exceptions import DecodeError, FileError, FormParserError, MultipartParseError, QuerystringParseError from multipart.multipart import ( BaseParser, Field, @@ -20,7 +20,6 @@ FormParser, MultipartParser, OctetStreamParser, - QuerystringParseError, QuerystringParser, create_form_parser, parse_form, @@ -30,13 +29,21 @@ from .compat import parametrize, parametrize_class if TYPE_CHECKING: - from multipart.multipart import FileConfig + from typing import Any, Iterator, TypedDict + + from multipart.multipart import FieldProtocol, FileConfig, FileProtocol + + class TestParams(TypedDict): + name: str + test: bytes + result: Any + # Get the current directory for our later test cases. curr_dir = os.path.abspath(os.path.dirname(__file__)) -def force_bytes(val): +def force_bytes(val: str | bytes) -> bytes: if isinstance(val, str): val = val.encode(sys.getfilesystemencoding()) @@ -44,33 +51,33 @@ def force_bytes(val): class TestField(unittest.TestCase): - def setUp(self): - self.f = Field("foo") + def setUp(self) -> None: + self.f = Field(b"foo") - def test_name(self): - self.assertEqual(self.f.field_name, "foo") + def test_name(self) -> None: + self.assertEqual(self.f.field_name, b"foo") - def test_data(self): + def test_data(self) -> None: self.f.write(b"test123") self.assertEqual(self.f.value, b"test123") - def test_cache_expiration(self): + def test_cache_expiration(self) -> None: self.f.write(b"test") self.assertEqual(self.f.value, b"test") self.f.write(b"123") self.assertEqual(self.f.value, b"test123") - def test_finalize(self): + def test_finalize(self) -> None: self.f.write(b"test123") self.f.finalize() self.assertEqual(self.f.value, b"test123") - def test_close(self): + def test_close(self) -> None: self.f.write(b"test123") self.f.close() self.assertEqual(self.f.value, b"test123") - def test_from_value(self): + def test_from_value(self) -> None: f = Field.from_value(b"name", b"value") self.assertEqual(f.field_name, b"name") self.assertEqual(f.value, b"value") @@ -78,18 +85,18 @@ def test_from_value(self): f2 = Field.from_value(b"name", None) self.assertEqual(f2.value, None) - def test_equality(self): + def test_equality(self) -> None: f1 = Field.from_value(b"name", b"value") f2 = Field.from_value(b"name", b"value") self.assertEqual(f1, f2) - def test_equality_with_other(self): + def test_equality_with_other(self) -> None: f = Field.from_value(b"foo", b"bar") self.assertFalse(f == b"foo") self.assertFalse(b"foo" == f) - def test_set_none(self): + def test_set_none(self) -> None: f = Field(b"foo") self.assertEqual(f.value, b"") @@ -98,34 +105,35 @@ def test_set_none(self): class TestFile(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.c: FileConfig = {} self.d = force_bytes(tempfile.mkdtemp()) self.f = File(b"foo.txt", config=self.c) - def assert_data(self, data): + def assert_data(self, data: bytes) -> None: f = self.f.file_object f.seek(0) self.assertEqual(f.read(), data) f.seek(0) f.truncate() - def assert_exists(self): + def assert_exists(self) -> None: + assert self.f.actual_file_name is not None full_path = os.path.join(self.d, self.f.actual_file_name) self.assertTrue(os.path.exists(full_path)) - def test_simple(self): + def test_simple(self) -> None: self.f.write(b"foobar") self.assert_data(b"foobar") - def test_invalid_write(self): + def test_invalid_write(self) -> None: m = Mock() m.write.return_value = 5 self.f._fileobj = m v = self.f.write(b"foobar") self.assertEqual(v, 5) - def test_file_fallback(self): + def test_file_fallback(self) -> None: self.c["MAX_MEMORY_FILE_SIZE"] = 1 self.f.write(b"1") @@ -142,7 +150,7 @@ def test_file_fallback(self): self.assertFalse(self.f.in_memory) self.assertIs(self.f.file_object, old_obj) - def test_file_fallback_with_data(self): + def test_file_fallback_with_data(self) -> None: self.c["MAX_MEMORY_FILE_SIZE"] = 10 self.f.write(b"1" * 10) @@ -153,7 +161,7 @@ def test_file_fallback_with_data(self): self.assert_data(b"11111111112222222222") - def test_file_name(self): + def test_file_name(self) -> None: # Write to this dir. self.c["UPLOAD_DIR"] = self.d self.c["MAX_MEMORY_FILE_SIZE"] = 10 @@ -166,7 +174,7 @@ def test_file_name(self): self.assertIsNotNone(self.f.actual_file_name) self.assert_exists() - def test_file_full_name(self): + def test_file_full_name(self) -> None: # Write to this dir. self.c["UPLOAD_DIR"] = self.d self.c["UPLOAD_KEEP_FILENAME"] = True @@ -180,7 +188,7 @@ def test_file_full_name(self): self.assertEqual(self.f.actual_file_name, b"foo") self.assert_exists() - def test_file_full_name_with_ext(self): + def test_file_full_name_with_ext(self) -> None: self.c["UPLOAD_DIR"] = self.d self.c["UPLOAD_KEEP_FILENAME"] = True self.c["UPLOAD_KEEP_EXTENSIONS"] = True @@ -194,7 +202,7 @@ def test_file_full_name_with_ext(self): self.assertEqual(self.f.actual_file_name, b"foo.txt") self.assert_exists() - def test_no_dir_with_extension(self): + def test_no_dir_with_extension(self) -> None: self.c["UPLOAD_KEEP_EXTENSIONS"] = True self.c["MAX_MEMORY_FILE_SIZE"] = 10 @@ -203,11 +211,12 @@ def test_no_dir_with_extension(self): self.assertFalse(self.f.in_memory) # Assert that the file exists + assert self.f.actual_file_name is not None ext = os.path.splitext(self.f.actual_file_name)[1] self.assertEqual(ext, b".txt") self.assert_exists() - def test_invalid_dir_with_name(self): + def test_invalid_dir_with_name(self) -> None: # Write to this dir. self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting")) self.c["UPLOAD_KEEP_FILENAME"] = True @@ -217,7 +226,7 @@ def test_invalid_dir_with_name(self): with self.assertRaises(FileError): self.f.write(b"1234567890") - def test_invalid_dir_no_name(self): + def test_invalid_dir_no_name(self) -> None: # Write to this dir. self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting")) self.c["UPLOAD_KEEP_FILENAME"] = False @@ -231,50 +240,50 @@ def test_invalid_dir_no_name(self): class TestParseOptionsHeader(unittest.TestCase): - def test_simple(self): + def test_simple(self) -> None: t, p = parse_options_header("application/json") self.assertEqual(t, b"application/json") self.assertEqual(p, {}) - def test_blank(self): + def test_blank(self) -> None: t, p = parse_options_header("") self.assertEqual(t, b"") self.assertEqual(p, {}) - def test_single_param(self): + def test_single_param(self) -> None: t, p = parse_options_header("application/json;par=val") self.assertEqual(t, b"application/json") self.assertEqual(p, {b"par": b"val"}) - def test_single_param_with_spaces(self): + def test_single_param_with_spaces(self) -> None: t, p = parse_options_header(b"application/json; par=val") self.assertEqual(t, b"application/json") self.assertEqual(p, {b"par": b"val"}) - def test_multiple_params(self): + def test_multiple_params(self) -> None: t, p = parse_options_header(b"application/json;par=val;asdf=foo") self.assertEqual(t, b"application/json") self.assertEqual(p, {b"par": b"val", b"asdf": b"foo"}) - def test_quoted_param(self): + def test_quoted_param(self) -> None: t, p = parse_options_header(b'application/json;param="quoted"') self.assertEqual(t, b"application/json") self.assertEqual(p, {b"param": b"quoted"}) - def test_quoted_param_with_semicolon(self): + def test_quoted_param_with_semicolon(self) -> None: t, p = parse_options_header(b'application/json;param="quoted;with;semicolons"') self.assertEqual(p[b"param"], b"quoted;with;semicolons") - def test_quoted_param_with_escapes(self): + def test_quoted_param_with_escapes(self) -> None: t, p = parse_options_header(b'application/json;param="This \\" is \\" a \\" quote"') self.assertEqual(p[b"param"], b'This " is " a " quote') - def test_handles_ie6_bug(self): + def test_handles_ie6_bug(self) -> None: t, p = parse_options_header(b'text/plain; filename="C:\\this\\is\\a\\path\\file.txt"') self.assertEqual(p[b"filename"], b"file.txt") - def test_redos_attack_header(self): + def test_redos_attack_header(self) -> None: t, p = parse_options_header( b'application/x-www-form-urlencoded; !="' b"\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\" @@ -282,47 +291,47 @@ def test_redos_attack_header(self): # If vulnerable, this test wouldn't finish, the line above would hang self.assertIn(b'"\\', p[b"!"]) - def test_handles_rfc_2231(self): + def test_handles_rfc_2231(self) -> None: t, p = parse_options_header(b"text/plain; param*=us-ascii'en-us'encoded%20message") self.assertEqual(p[b"param"], b"encoded message") class TestBaseParser(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.b = BaseParser() self.b.callbacks = {} - def test_callbacks(self): + def test_callbacks(self) -> None: called = 0 - def on_foo(): + def on_foo() -> None: nonlocal called called += 1 - self.b.set_callback("foo", on_foo) - self.b.callback("foo") + self.b.set_callback("foo", on_foo) # type: ignore[arg-type] + self.b.callback("foo") # type: ignore[arg-type] self.assertEqual(called, 1) - self.b.set_callback("foo", None) - self.b.callback("foo") + self.b.set_callback("foo", None) # type: ignore[arg-type] + self.b.callback("foo") # type: ignore[arg-type] self.assertEqual(called, 1) class TestQuerystringParser(unittest.TestCase): - def assert_fields(self, *args, **kwargs): + def assert_fields(self, *args: tuple[bytes, bytes], **kwargs: Any) -> None: if kwargs.pop("finalize", True): self.p.finalize() self.assertEqual(self.f, list(args)) if kwargs.get("reset", True): - self.f = [] + self.f: list[tuple[bytes, bytes]] = [] - def setUp(self): + def setUp(self) -> None: self.reset() - def reset(self): - self.f: list[tuple[bytes, bytes]] = [] + def reset(self) -> None: + self.f = [] name_buffer: list[bytes] = [] data_buffer: list[bytes] = [] @@ -333,7 +342,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None: def on_field_data(data: bytes, start: int, end: int) -> None: data_buffer.append(data[start:end]) - def on_field_end(): + def on_field_end() -> None: self.f.append((b"".join(name_buffer), b"".join(data_buffer))) del name_buffer[:] @@ -343,34 +352,34 @@ def on_field_end(): callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end} ) - def test_simple_querystring(self): + def test_simple_querystring(self) -> None: self.p.write(b"foo=bar") self.assert_fields((b"foo", b"bar")) - def test_querystring_blank_beginning(self): + def test_querystring_blank_beginning(self) -> None: self.p.write(b"&foo=bar") self.assert_fields((b"foo", b"bar")) - def test_querystring_blank_end(self): + def test_querystring_blank_end(self) -> None: self.p.write(b"foo=bar&") self.assert_fields((b"foo", b"bar")) - def test_multiple_querystring(self): + def test_multiple_querystring(self) -> None: self.p.write(b"foo=bar&asdf=baz") self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz")) - def test_streaming_simple(self): + def test_streaming_simple(self) -> None: self.p.write(b"foo=bar&") self.assert_fields((b"foo", b"bar"), finalize=False) self.p.write(b"asdf=baz") self.assert_fields((b"asdf", b"baz")) - def test_streaming_break(self): + def test_streaming_break(self) -> None: self.p.write(b"foo=one") self.assert_fields(finalize=False) @@ -386,12 +395,12 @@ def test_streaming_break(self): self.p.write(b"f=baz") self.assert_fields((b"asdf", b"baz")) - def test_semicolon_separator(self): + def test_semicolon_separator(self) -> None: self.p.write(b"foo=bar;asdf=baz") self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz")) - def test_too_large_field(self): + def test_too_large_field(self) -> None: self.p.max_size = 15 # Note: len = 8 @@ -402,11 +411,11 @@ def test_too_large_field(self): self.p.write(b"a=123456") self.assert_fields((b"a", b"12345")) - def test_invalid_max_size(self): + def test_invalid_max_size(self) -> None: with self.assertRaises(ValueError): p = QuerystringParser(max_size=-100) - def test_strict_parsing_pass(self): + def test_strict_parsing_pass(self) -> None: data = b"foo=bar&another=asdf" for first, last in split_all(data): self.reset() @@ -418,7 +427,7 @@ def test_strict_parsing_pass(self): self.p.write(last) self.assert_fields((b"foo", b"bar"), (b"another", b"asdf")) - def test_strict_parsing_fail_double_sep(self): + def test_strict_parsing_fail_double_sep(self) -> None: data = b"foo=bar&&another=asdf" for first, last in split_all(data): self.reset() @@ -435,7 +444,7 @@ def test_strict_parsing_fail_double_sep(self): if cm is not None: self.assertEqual(cm.exception.offset, 8 - cnt) - def test_double_sep(self): + def test_double_sep(self) -> None: data = b"foo=bar&&another=asdf" for first, last in split_all(data): print(f" {first!r} / {last!r} ") @@ -447,7 +456,7 @@ def test_double_sep(self): self.assert_fields((b"foo", b"bar"), (b"another", b"asdf")) - def test_strict_parsing_fail_no_value(self): + def test_strict_parsing_fail_no_value(self) -> None: self.p.strict_parsing = True with self.assertRaises(QuerystringParseError) as cm: self.p.write(b"foo=bar&blank&another=asdf") @@ -455,18 +464,18 @@ def test_strict_parsing_fail_no_value(self): if cm is not None: self.assertEqual(cm.exception.offset, 8) - def test_success_no_value(self): + def test_success_no_value(self) -> None: self.p.write(b"foo=bar&blank&another=asdf") self.assert_fields((b"foo", b"bar"), (b"blank", b""), (b"another", b"asdf")) - def test_repr(self): + def test_repr(self) -> None: # Issue #29; verify we don't assert on repr() _ignored = repr(self.p) class TestOctetStreamParser(unittest.TestCase): - def setUp(self): - self.d = [] + def setUp(self) -> None: + self.d: list[bytes] = [] self.started = 0 self.finished = 0 @@ -481,23 +490,23 @@ def on_end() -> None: self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}) - def assert_data(self, data, finalize=True): + def assert_data(self, data: bytes, finalize: bool = True) -> None: self.assertEqual(b"".join(self.d), data) self.d = [] - def assert_started(self, val=True): + def assert_started(self, val: bool = True) -> None: if val: self.assertEqual(self.started, 1) else: self.assertEqual(self.started, 0) - def assert_finished(self, val=True): + def assert_finished(self, val: bool = True) -> None: if val: self.assertEqual(self.finished, 1) else: self.assertEqual(self.finished, 0) - def test_simple(self): + def test_simple(self) -> None: # Assert is not started self.assert_started(False) @@ -511,7 +520,7 @@ def test_simple(self): self.p.finalize() self.assert_finished() - def test_multiple_chunks(self): + def test_multiple_chunks(self) -> None: self.p.write(b"foo") self.p.write(b"bar") self.p.write(b"baz") @@ -520,7 +529,7 @@ def test_multiple_chunks(self): self.assert_data(b"foobarbaz") self.assert_finished() - def test_max_size(self): + def test_max_size(self) -> None: self.p.max_size = 5 self.p.write(b"0123456789") @@ -529,18 +538,18 @@ def test_max_size(self): self.assert_data(b"01234") self.assert_finished() - def test_invalid_max_size(self): + def test_invalid_max_size(self) -> None: with self.assertRaises(ValueError): - q = OctetStreamParser(max_size="foo") + q = OctetStreamParser(max_size="foo") # type: ignore[arg-type] class TestBase64Decoder(unittest.TestCase): # Note: base64('foobar') == 'Zm9vYmFy' - def setUp(self): + def setUp(self) -> None: self.f = BytesIO() self.d = Base64Decoder(self.f) - def assert_data(self, data, finalize=True): + def assert_data(self, data: bytes, finalize: bool = True) -> None: if finalize: self.d.finalize() @@ -549,20 +558,20 @@ def assert_data(self, data, finalize=True): self.f.seek(0) self.f.truncate() - def test_simple(self): + def test_simple(self) -> None: self.d.write(b"Zm9vYmFy") self.assert_data(b"foobar") - def test_bad(self): + def test_bad(self) -> None: with self.assertRaises(DecodeError): self.d.write(b"Zm9v!mFy") - def test_split_properly(self): + def test_split_properly(self) -> None: self.d.write(b"Zm9v") self.d.write(b"YmFy") self.assert_data(b"foobar") - def test_bad_split(self): + def test_bad_split(self) -> None: buff = b"Zm9v" for i in range(1, 4): first, second = buff[:i], buff[i:] @@ -572,7 +581,7 @@ def test_bad_split(self): self.d.write(second) self.assert_data(b"foo") - def test_long_bad_split(self): + def test_long_bad_split(self) -> None: buff = b"Zm9vYmFy" for i in range(5, 8): first, second = buff[:i], buff[i:] @@ -582,7 +591,7 @@ def test_long_bad_split(self): self.d.write(second) self.assert_data(b"foobar") - def test_close_and_finalize(self): + def test_close_and_finalize(self) -> None: parser = Mock() f = Base64Decoder(parser) @@ -592,7 +601,7 @@ def test_close_and_finalize(self): f.close() parser.close.assert_called_once_with() - def test_bad_length(self): + def test_bad_length(self) -> None: self.d.write(b"Zm9vYmF") # missing ending 'y' with self.assertRaises(DecodeError): @@ -600,11 +609,11 @@ def test_bad_length(self): class TestQuotedPrintableDecoder(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.f = BytesIO() self.d = QuotedPrintableDecoder(self.f) - def assert_data(self, data, finalize=True): + def assert_data(self, data: bytes, finalize: bool = True) -> None: if finalize: self.d.finalize() @@ -613,38 +622,38 @@ def assert_data(self, data, finalize=True): self.f.seek(0) self.f.truncate() - def test_simple(self): + def test_simple(self) -> None: self.d.write(b"foobar") self.assert_data(b"foobar") - def test_with_escape(self): + def test_with_escape(self) -> None: self.d.write(b"foo=3Dbar") self.assert_data(b"foo=bar") - def test_with_newline_escape(self): + def test_with_newline_escape(self) -> None: self.d.write(b"foo=\r\nbar") self.assert_data(b"foobar") - def test_with_only_newline_escape(self): + def test_with_only_newline_escape(self) -> None: self.d.write(b"foo=\nbar") self.assert_data(b"foobar") - def test_with_split_escape(self): + def test_with_split_escape(self) -> None: self.d.write(b"foo=3") self.d.write(b"Dbar") self.assert_data(b"foo=bar") - def test_with_split_newline_escape_1(self): + def test_with_split_newline_escape_1(self) -> None: self.d.write(b"foo=\r") self.d.write(b"\nbar") self.assert_data(b"foobar") - def test_with_split_newline_escape_2(self): + def test_with_split_newline_escape_2(self) -> None: self.d.write(b"foo=") self.d.write(b"\r\nbar") self.assert_data(b"foobar") - def test_close_and_finalize(self): + def test_close_and_finalize(self) -> None: parser = Mock() f = QuotedPrintableDecoder(parser) @@ -654,7 +663,7 @@ def test_close_and_finalize(self): f.close() parser.close.assert_called_once_with() - def test_not_aligned(self): + def test_not_aligned(self) -> None: """ https://github.com/andrew-d/python-multipart/issues/6 """ @@ -675,7 +684,7 @@ def test_not_aligned(self): # Read in all test cases and load them. NON_PARAMETRIZED_TESTS = {"single_field_blocks"} -http_tests = [] +http_tests: list[TestParams] = [] for f in os.listdir(http_tests_dir): # Only load the HTTP test cases. fname, ext = os.path.splitext(f) @@ -687,11 +696,11 @@ def test_not_aligned(self): yaml_file = os.path.join(http_tests_dir, fname + ".yaml") # Load both. - with open(os.path.join(http_tests_dir, f), "rb") as f: - test_data = f.read() + with open(os.path.join(http_tests_dir, f), "rb") as fh: + test_data = fh.read() - with open(yaml_file, "rb") as f: - yaml_data = yaml.safe_load(f) + with open(yaml_file, "rb") as fy: + yaml_data = yaml.safe_load(fy) http_tests.append({"name": fname, "test": test_data, "result": yaml_data}) @@ -704,7 +713,8 @@ def test_not_aligned(self): "single_field_single_file", ] -def split_all(val): + +def split_all(val: bytes) -> Iterator[tuple[bytes, bytes]]: """ This function will split an array all possible ways. For example: split_all([1,2,3,4]) @@ -717,30 +727,30 @@ def split_all(val): @parametrize_class class TestFormParser(unittest.TestCase): - def make(self, boundary, config={}): + def make(self, boundary: str | bytes, config: dict[str, Any] = {}) -> None: self.ended = False self.files: list[File] = [] self.fields: list[Field] = [] - def on_field(f: Field) -> None: - self.fields.append(f) + def on_field(f: FieldProtocol) -> None: + self.fields.append(cast(Field, f)) - def on_file(f: File) -> None: - self.files.append(f) + def on_file(f: FileProtocol) -> None: + self.files.append(cast(File, f)) - def on_end(): + def on_end() -> None: self.ended = True # Get a form-parser instance. self.f = FormParser("multipart/form-data", on_field, on_file, on_end, boundary=boundary, config=config) - def assert_file_data(self, f, data): + def assert_file_data(self, f: File, data: bytes) -> None: o = f.file_object o.seek(0) file_data = o.read() self.assertEqual(file_data, data) - def assert_file(self, field_name, file_name, data): + def assert_file(self, field_name: bytes, file_name: bytes, data: bytes) -> None: # Find this file. found = None for f in self.files: @@ -750,6 +760,7 @@ def assert_file(self, field_name, file_name, data): # Assert that we found it. self.assertIsNotNone(found) + assert found is not None try: # Assert about this file. @@ -762,7 +773,7 @@ def assert_file(self, field_name, file_name, data): # Close our file found.close() - def assert_field(self, name, value): + def assert_field(self, name: bytes, value: bytes) -> None: # Find this field in our fields list. found = None for f in self.fields: @@ -772,13 +783,14 @@ def assert_field(self, name, value): # Assert that it exists and matches. self.assertIsNotNone(found) + assert found is not None # typing self.assertEqual(value, found.value) # Remove it for future iterations. self.fields.remove(found) @parametrize("param", http_tests) - def test_http(self, param): + def test_http(self, param: TestParams) -> None: # Firstly, create our parser with the given boundary. boundary = param["result"]["boundary"] if isinstance(boundary, str): @@ -790,9 +802,9 @@ def test_http(self, param): try: processed = self.f.write(param["test"]) self.f.finalize() - except MultipartParseError as e: + except MultipartParseError as err: processed = 0 - exc = e + exc = err # print(repr(param)) # print("") @@ -802,6 +814,7 @@ def test_http(self, param): # Do we expect an error? if "error" in param["result"]["expected"]: self.assertIsNotNone(exc) + assert exc is not None self.assertEqual(param["result"]["expected"]["error"], exc.offset) return @@ -823,7 +836,7 @@ def test_http(self, param): else: assert False - def test_random_splitting(self): + def test_random_splitting(self) -> None: """ This test runs a simple multipart body with one field and one file through every possible split. @@ -851,8 +864,8 @@ def test_random_splitting(self): self.assert_field(b"field", b"test1") self.assert_file(b"file", b"file.txt", b"test2") - @parametrize("param", [ t for t in http_tests if t["name"] in single_byte_tests]) - def test_feed_single_bytes(self, param): + @parametrize("param", [t for t in http_tests if t["name"] in single_byte_tests]) + def test_feed_single_bytes(self, param: TestParams) -> None: """ This test parses multipart bodies 1 byte at a time. """ @@ -893,7 +906,7 @@ def test_feed_single_bytes(self, param): else: assert False - def test_feed_blocks(self): + def test_feed_blocks(self) -> None: """ This test parses a simple multipart body 1 byte at a time. """ @@ -926,7 +939,7 @@ def test_feed_blocks(self): # Assert that our field is here. self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ") - def test_request_body_fuzz(self): + def test_request_body_fuzz(self) -> None: """ This test randomly fuzzes the request body to ensure that no strange exceptions are raised and we don't end up in a strange state. The @@ -998,7 +1011,7 @@ def test_request_body_fuzz(self): print("Failures: %d" % (failures,)) print("Exceptions: %d" % (exceptions,)) - def test_request_body_fuzz_random_data(self): + def test_request_body_fuzz_random_data(self) -> None: """ This test will fuzz the multipart parser with some number of iterations of randomly-generated data. @@ -1035,7 +1048,7 @@ def test_request_body_fuzz_random_data(self): print("Failures: %d" % (failures,)) print("Exceptions: %d" % (exceptions,)) - def test_bad_start_boundary(self): + def test_bad_start_boundary(self) -> None: self.make("boundary") data = b"--boundary\rfoobar" with self.assertRaises(MultipartParseError): @@ -1046,11 +1059,11 @@ def test_bad_start_boundary(self): with self.assertRaises(MultipartParseError): i = self.f.write(data) - def test_octet_stream(self): - files = [] + def test_octet_stream(self) -> None: + files: list[File] = [] - def on_file(f): - files.append(f) + def on_file(f: FileProtocol) -> None: + files.append(cast(File, f)) on_field = Mock() on_end = Mock() @@ -1068,16 +1081,16 @@ def on_file(f): self.assert_file_data(files[0], b"test1234") self.assertTrue(on_end.called) - def test_querystring(self): - fields = [] + def test_querystring(self) -> None: + fields: list[Field] = [] - def on_field(f): - fields.append(f) + def on_field(f: FieldProtocol) -> None: + fields.append(cast(Field, f)) on_file = Mock() on_end = Mock() - def simple_test(f): + def simple_test(f: FormParser) -> None: # Reset tracking. del fields[:] on_file.reset_mock() @@ -1110,7 +1123,7 @@ def simple_test(f): self.assertTrue(isinstance(f.parser, QuerystringParser)) simple_test(f) - def test_close_methods(self): + def test_close_methods(self) -> None: parser = Mock() f = FormParser("application/x-url-encoded", None, None) f.parser = parser @@ -1121,18 +1134,18 @@ def test_close_methods(self): f.close() parser.close.assert_called_once_with() - def test_bad_content_type(self): + def test_bad_content_type(self) -> None: # We should raise a ValueError for a bad Content-Type with self.assertRaises(ValueError): f = FormParser("application/bad", None, None) - def test_no_boundary_given(self): + def test_no_boundary_given(self) -> None: # We should raise a FormParserError when parsing a multipart message # without a boundary. with self.assertRaises(FormParserError): f = FormParser("multipart/form-data", None, None) - def test_bad_content_transfer_encoding(self): + def test_bad_content_transfer_encoding(self) -> None: data = ( b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\n' b"Content-Type: text/plain\r\n" @@ -1140,10 +1153,10 @@ def test_bad_content_transfer_encoding(self): b"Test\r\n----boundary--\r\n" ) - files = [] + files: list[File] = [] - def on_file(f): - files.append(f) + def on_file(f: FileProtocol) -> None: + files.append(cast(File, f)) on_field = Mock() on_end = Mock() @@ -1164,11 +1177,11 @@ def on_file(f): f.finalize() self.assert_file_data(files[0], b"Test") - def test_handles_None_fields(self): - fields = [] + def test_handles_None_fields(self) -> None: + fields: list[Field] = [] - def on_field(f): - fields.append(f) + def on_field(f: FieldProtocol) -> None: + fields.append(cast(Field, f)) on_file = Mock() on_end = Mock() @@ -1186,7 +1199,7 @@ def on_field(f): self.assertEqual(fields[2].field_name, b"baz") self.assertEqual(fields[2].value, b"asdf") - def test_max_size_multipart(self): + def test_max_size_multipart(self) -> None: # Load test data. test_file = "single_field_single_file.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: @@ -1197,7 +1210,8 @@ def test_max_size_multipart(self): # Set the maximum length that we can process to be halfway through the # given data. - self.f.parser.max_size = len(test_data) / 2 + assert self.f.parser is not None + self.f.parser.max_size = float(len(test_data)) / 2 i = self.f.write(test_data) self.f.finalize() @@ -1205,7 +1219,7 @@ def test_max_size_multipart(self): # Assert we processed the correct amount. self.assertEqual(i, len(test_data) / 2) - def test_max_size_form_parser(self): + def test_max_size_form_parser(self) -> None: # Load test data. test_file = "single_field_single_file.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: @@ -1222,11 +1236,11 @@ def test_max_size_form_parser(self): # Assert we processed the correct amount. self.assertEqual(i, len(test_data) / 2) - def test_octet_stream_max_size(self): - files = [] + def test_octet_stream_max_size(self) -> None: + files: list[File] = [] - def on_file(f): - files.append(f) + def on_file(f: FileProtocol) -> None: + files.append(cast(File, f)) on_field = Mock() on_end = Mock() @@ -1245,11 +1259,11 @@ def on_file(f): self.assert_file_data(files[0], b"0123456789") - def test_invalid_max_size_multipart(self): + def test_invalid_max_size_multipart(self) -> None: with self.assertRaises(ValueError): - MultipartParser(b"bound", max_size="foo") + MultipartParser(b"bound", max_size="foo") # type: ignore[arg-type] - def test_header_begin_callback(self): + def test_header_begin_callback(self) -> None: """ This test verifies we call the `on_header_begin` callback. See GitHub issue #23 @@ -1280,20 +1294,20 @@ def on_header_begin() -> None: class TestHelperFunctions(unittest.TestCase): - def test_create_form_parser(self): - r = create_form_parser({"Content-Type": "application/octet-stream"}, None, None) + def test_create_form_parser(self) -> None: + r = create_form_parser({"Content-Type": b"application/octet-stream"}, None, None) self.assertTrue(isinstance(r, FormParser)) - def test_create_form_parser_error(self): - headers = {} + def test_create_form_parser_error(self) -> None: + headers: dict[str, bytes] = {} with self.assertRaises(ValueError): create_form_parser(headers, None, None) - def test_parse_form(self): + def test_parse_form(self) -> None: on_field = Mock() on_file = Mock() - parse_form({"Content-Type": "application/octet-stream"}, BytesIO(b"123456789012345"), on_field, on_file) + parse_form({"Content-Type": b"application/octet-stream"}, BytesIO(b"123456789012345"), on_field, on_file) assert on_file.call_count == 1 @@ -1301,24 +1315,27 @@ def test_parse_form(self): # 15 - i.e. all data is written. self.assertEqual(on_file.call_args[0][0].size, 15) - def test_parse_form_content_length(self): - files = [] + def test_parse_form_content_length(self) -> None: + files: list[FileProtocol] = [] - def on_file(file): + def on_field(field: FieldProtocol) -> None: + pass + + def on_file(file: FileProtocol) -> None: files.append(file) parse_form( - {"Content-Type": "application/octet-stream", "Content-Length": "10"}, + {"Content-Type": b"application/octet-stream", "Content-Length": b"10"}, BytesIO(b"123456789012345"), - None, + on_field, on_file, ) self.assertEqual(len(files), 1) - self.assertEqual(files[0].size, 10) + self.assertEqual(files[0].size, 10) # type: ignore[attr-defined] -def suite(): +def suite() -> unittest.TestSuite: suite = unittest.TestSuite() suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile)) suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestParseOptionsHeader)) diff --git a/uv.lock b/uv.lock index 69f3835..2ae1c4e 100644 --- a/uv.lock +++ b/uv.lock @@ -524,6 +524,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/e2/8e10e465ee3987bb7c9ab69efb91d867d93959095f4807db102d07995d94/more_itertools-10.2.0-py3-none-any.whl", hash = "sha256:686b06abe565edfab151cb8fd385a05651e1fdf8f0a14191e4439283421f8684", size = 57015 }, ] +[[package]] +name = "mypy" +version = "1.11.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/86/5d7cbc4974fd564550b80fbb8103c05501ea11aa7835edf3351d90095896/mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79", size = 3078806 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/cd/815368cd83c3a31873e5e55b317551500b12f2d1d7549720632f32630333/mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a", size = 10939401 }, + { url = "https://files.pythonhosted.org/packages/f1/27/e18c93a195d2fad75eb96e1f1cbc431842c332e8eba2e2b77eaf7313c6b7/mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef", size = 10111697 }, + { url = "https://files.pythonhosted.org/packages/dc/08/cdc1fc6d0d5a67d354741344cc4aa7d53f7128902ebcbe699ddd4f15a61c/mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383", size = 12500508 }, + { url = "https://files.pythonhosted.org/packages/64/12/aad3af008c92c2d5d0720ea3b6674ba94a98cdb86888d389acdb5f218c30/mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8", size = 13020712 }, + { url = "https://files.pythonhosted.org/packages/03/e6/a7d97cc124a565be5e9b7d5c2a6ebf082379ffba99646e4863ed5bbcb3c3/mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7", size = 9567319 }, + { url = "https://files.pythonhosted.org/packages/e2/aa/cc56fb53ebe14c64f1fe91d32d838d6f4db948b9494e200d2f61b820b85d/mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385", size = 10859630 }, + { url = "https://files.pythonhosted.org/packages/04/c8/b19a760fab491c22c51975cf74e3d253b8c8ce2be7afaa2490fbf95a8c59/mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca", size = 10037973 }, + { url = "https://files.pythonhosted.org/packages/88/57/7e7e39f2619c8f74a22efb9a4c4eff32b09d3798335625a124436d121d89/mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104", size = 12416659 }, + { url = "https://files.pythonhosted.org/packages/fc/a6/37f7544666b63a27e46c48f49caeee388bf3ce95f9c570eb5cfba5234405/mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4", size = 12897010 }, + { url = "https://files.pythonhosted.org/packages/84/8b/459a513badc4d34acb31c736a0101c22d2bd0697b969796ad93294165cfb/mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6", size = 9562873 }, + { url = "https://files.pythonhosted.org/packages/35/3a/ed7b12ecc3f6db2f664ccf85cb2e004d3e90bec928e9d7be6aa2f16b7cdf/mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318", size = 10990335 }, + { url = "https://files.pythonhosted.org/packages/04/e4/1a9051e2ef10296d206519f1df13d2cc896aea39e8683302f89bf5792a59/mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36", size = 10007119 }, + { url = "https://files.pythonhosted.org/packages/f3/3c/350a9da895f8a7e87ade0028b962be0252d152e0c2fbaafa6f0658b4d0d4/mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987", size = 12506856 }, + { url = "https://files.pythonhosted.org/packages/b6/49/ee5adf6a49ff13f4202d949544d3d08abb0ea1f3e7f2a6d5b4c10ba0360a/mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca", size = 12952066 }, + { url = "https://files.pythonhosted.org/packages/27/c0/b19d709a42b24004d720db37446a42abadf844d5c46a2c442e2a074d70d9/mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70", size = 9664000 }, + { url = "https://files.pythonhosted.org/packages/42/ad/5a8567700410f8aa7c755b0ebd4cacff22468cbc5517588773d65075c0cb/mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b", size = 10876550 }, + { url = "https://files.pythonhosted.org/packages/1b/bc/9fc16ea7a27ceb93e123d300f1cfe27a6dd1eac9a8beea4f4d401e737e9d/mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86", size = 10068086 }, + { url = "https://files.pythonhosted.org/packages/cd/8f/a1e460f1288405a13352dad16b24aba6dce4f850fc76510c540faa96eda3/mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce", size = 12459214 }, + { url = "https://files.pythonhosted.org/packages/c7/74/746b31aef7cc7512dab8bdc2311ef88d63fadc1c453a09c8cab7e57e59bf/mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1", size = 12962942 }, + { url = "https://files.pythonhosted.org/packages/28/a4/7fae712240b640d75bb859294ad4776b9960b3216ccb7fa747f578e6c632/mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b", size = 9545616 }, + { url = "https://files.pythonhosted.org/packages/16/64/bb5ed751487e2bea0dfaa6f640a7e3bb88083648f522e766d5ef4a76f578/mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6", size = 10937294 }, + { url = "https://files.pythonhosted.org/packages/a9/a3/67a0069abed93c3bf3b0bebb8857e2979a02828a4a3fd82f107f8f1143e8/mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70", size = 10107707 }, + { url = "https://files.pythonhosted.org/packages/2f/4d/0379daf4258b454b1f9ed589a9dabd072c17f97496daea7b72fdacf7c248/mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d", size = 12498367 }, + { url = "https://files.pythonhosted.org/packages/3b/dc/3976a988c280b3571b8eb6928882dc4b723a403b21735a6d8ae6ed20e82b/mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d", size = 13018014 }, + { url = "https://files.pythonhosted.org/packages/83/84/adffc7138fb970e7e2a167bd20b33bb78958370179853a4ebe9008139342/mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24", size = 9568056 }, + { url = "https://files.pythonhosted.org/packages/42/3a/bdf730640ac523229dd6578e8a581795720a9321399de494374afc437ec5/mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12", size = 2619625 }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + [[package]] name = "packaging" version = "24.1" @@ -665,7 +713,7 @@ wheels = [ [[package]] name = "python-multipart" -version = "0.0.10" +version = "0.0.11" source = { editable = "." } [package.dev-dependencies] @@ -680,6 +728,7 @@ dev = [ { name = "mkdocs-material" }, { name = "mkdocstrings-python" }, { name = "more-itertools" }, + { name = "mypy" }, { name = "pbr" }, { name = "pluggy" }, { name = "py" }, @@ -688,6 +737,7 @@ dev = [ { name = "pytest-timeout" }, { name = "pyyaml" }, { name = "ruff" }, + { name = "types-pyyaml" }, ] [package.metadata] @@ -704,6 +754,7 @@ dev = [ { name = "mkdocs-material" }, { name = "mkdocstrings-python" }, { name = "more-itertools", specifier = "==10.2.0" }, + { name = "mypy" }, { name = "pbr", specifier = "==6.0.0" }, { name = "pluggy", specifier = "==1.4.0" }, { name = "py", specifier = "==1.11.0" }, @@ -712,6 +763,7 @@ dev = [ { name = "pytest-timeout", specifier = "==2.3.1" }, { name = "pyyaml", specifier = "==6.0.1" }, { name = "ruff", specifier = "==0.3.4" }, + { name = "types-pyyaml" }, ] [[package]] @@ -939,6 +991,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", size = 12757 }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20240917" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/92/7d/a95df0a11f95c8f48d7683f03e4aed1a2c0fc73e9de15cca4d38034bea1a/types-PyYAML-6.0.12.20240917.tar.gz", hash = "sha256:d1405a86f9576682234ef83bcb4e6fff7c9305c8b1fbad5e0bcd4f7dbdc9c587", size = 12381 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/2c/c1d81d680997d24b0542aa336f0a65bd7835e5224b7670f33a7d617da379/types_PyYAML-6.0.12.20240917-py3-none-any.whl", hash = "sha256:392b267f1c0fe6022952462bf5d6523f31e37f6cea49b14cee7ad634b6301570", size = 15264 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"