Skip to content

Commit

Permalink
Improve type hints on File (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 12, 2024
1 parent 2baf8b1 commit f4479c6
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 51 deletions.
19 changes: 10 additions & 9 deletions multipart/decoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import binascii
from io import BufferedWriter

from .exceptions import DecodeError

Expand Down Expand Up @@ -33,11 +34,11 @@ class Base64Decoder:
:param underlying: the underlying object to pass writes to
"""

def __init__(self, underlying):
def __init__(self, underlying: BufferedWriter):
self.cache = bytearray()
self.underlying = underlying

def write(self, data):
def write(self, data: bytes) -> int:
"""Takes any input data provided, decodes it as base64, and passes it
on to the underlying object. If the data provided is invalid base64
data, then this method will raise
Expand Down Expand Up @@ -80,7 +81,7 @@ def close(self) -> None:
if hasattr(self.underlying, "close"):
self.underlying.close()

def finalize(self):
def finalize(self) -> None:
"""Finalize this object. This should be called when no more data
should be written to the stream. This function can raise a
:class:`multipart.exceptions.DecodeError` if there is some remaining
Expand All @@ -97,7 +98,7 @@ def finalize(self):
if hasattr(self.underlying, "finalize"):
self.underlying.finalize()

def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(underlying={self.underlying!r})"


Expand All @@ -111,11 +112,11 @@ class QuotedPrintableDecoder:
:param underlying: the underlying object to pass writes to
"""

def __init__(self, underlying):
def __init__(self, underlying: BufferedWriter) -> None:
self.cache = b""
self.underlying = underlying

def write(self, data):
def write(self, data: bytes) -> int:
"""Takes any input data provided, decodes it as quoted-printable, and
passes it on to the underlying object.
Expand All @@ -142,14 +143,14 @@ def write(self, data):
self.cache = rest
return len(data)

def close(self):
def close(self) -> None:
"""Close this decoder. If the underlying object has a `close()`
method, this function will call it.
"""
if hasattr(self.underlying, "close"):
self.underlying.close()

def finalize(self):
def finalize(self) -> None:
"""Finalize this object. This should be called when no more data
should be written to the stream. This function will not raise any
exceptions, but it may write more data to the underlying object if
Expand All @@ -167,5 +168,5 @@ def finalize(self):
if hasattr(self.underlying, "finalize"):
self.underlying.finalize()

def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(underlying={self.underlying!r})"
74 changes: 33 additions & 41 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MultipartCallbacks(TypedDict, total=False):
on_headers_finished: Callable[[], None]
on_end: Callable[[], None]

class FormParserConfig(TypedDict, total=False):
class FormParserConfig(TypedDict):
UPLOAD_DIR: str | None
UPLOAD_KEEP_FILENAME: bool
UPLOAD_KEEP_EXTENSIONS: bool
Expand All @@ -50,7 +50,7 @@ class FormParserConfig(TypedDict, total=False):
MAX_BODY_SIZE: float

class FileConfig(TypedDict, total=False):
UPLOAD_DIR: str | None
UPLOAD_DIR: str | bytes | None
UPLOAD_DELETE_TMP: bool
UPLOAD_KEEP_FILENAME: bool
UPLOAD_KEEP_EXTENSIONS: bool
Expand Down Expand Up @@ -374,7 +374,7 @@ class File:
configuration keys and their corresponding values.
"""

def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}):
def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
# Save configuration, set other variables default.
self.logger = logging.getLogger(__name__)
self._config = config
Expand Down Expand Up @@ -471,7 +471,7 @@ def flush_to_disk(self) -> None:
# Close the old file object.
old_fileobj.close()

def _get_disk_file(self):
def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[bytes]: # type: ignore[reportPrivateUsage]
"""This function is responsible for getting a file object on-disk for us."""
self.logger.info("Opening a file on disk")

Expand All @@ -486,9 +486,7 @@ def _get_disk_file(self):

# Build our filename.
# TODO: what happens if we don't have a filename?
fname = self._file_base
if keep_extensions:
fname = fname + self._ext
fname = self._file_base + self._ext if keep_extensions else self._file_base

path = os.path.join(file_dir, fname)
try:
Expand All @@ -503,25 +501,21 @@ def _get_disk_file(self):
# Build options array.
# Note that on Python 3, tempfile doesn't support byte names. We
# encode our paths using the default filesystem encoding.
options = {}
if keep_extensions:
ext = self._ext
if isinstance(ext, bytes):
ext = ext.decode(sys.getfilesystemencoding())

options["suffix"] = ext
if file_dir is not None:
d = file_dir
if isinstance(d, bytes):
d = d.decode(sys.getfilesystemencoding())
suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

options["dir"] = d
options["delete"] = delete_tmp
if file_dir is None:
dir = None
elif isinstance(file_dir, bytes):
dir = file_dir.decode(sys.getfilesystemencoding())
else:
dir = file_dir

# Create a temporary (named) file with the appropriate settings.
self.logger.info("Creating a temporary file with options: %r", options)
self.logger.info(
"Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
)
try:
tmp_file = tempfile.NamedTemporaryFile(**options)
tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir)
except OSError:
self.logger.exception("Error creating named temporary file")
raise FileError("Error creating named temporary file")
Expand Down Expand Up @@ -563,11 +557,8 @@ def on_data(self, data: bytes) -> int:
self._bytes_written += bwritten

# If we're in-memory and are over our limit, we create a file.
if (
self._in_memory
and self._config.get("MAX_MEMORY_FILE_SIZE") is not None
and (self._bytes_written > self._config.get("MAX_MEMORY_FILE_SIZE"))
):
max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
self.logger.info("Flushing to disk")
self.flush_to_disk()

Expand Down Expand Up @@ -617,9 +608,7 @@ class BaseParser:
performance.
"""

callbacks: dict[str, Callable[..., Any]]

def __init__(self):
def __init__(self) -> None:
self.logger = logging.getLogger(__name__)

def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None):
Expand Down Expand Up @@ -706,7 +695,7 @@ class OctetStreamParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")):
super().__init__()
self.callbacks = callbacks
self._started = False
Expand All @@ -716,7 +705,7 @@ def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
self.max_size = max_size
self._current_size = 0

def write(self, data: bytes):
def write(self, data: bytes) -> int:
"""Write some data to the parser, which will perform size verification,
and then pass the data to the underlying callback.
Expand Down Expand Up @@ -803,7 +792,9 @@ class QuerystringParser(BaseParser):

state: QuerystringState

def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size=float("inf")):
def __init__(
self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf")
) -> None:
super().__init__()
self.state = QuerystringState.BEFORE_FIELD
self._found_sep = False
Expand Down Expand Up @@ -1060,7 +1051,9 @@ class MultipartParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
def __init__(
self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
) -> None:
# Initialize parser state.
super().__init__()
self.state = MultipartState.START
Expand Down Expand Up @@ -1618,8 +1611,8 @@ def __init__(
file_name: bytes | None = None,
FileClass: type[FileProtocol] = File,
FieldClass: type[FieldProtocol] = Field,
config: FormParserConfig = {},
):
config: dict[Any, Any] = {},
) -> None:
self.logger = logging.getLogger(__name__)

# Save variables.
Expand Down Expand Up @@ -1787,7 +1780,7 @@ def on_headers_finished() -> None:
# TODO: check that we properly handle 8bit / 7bit encoding.
transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit":
if transfer_encoding in (b"binary", b"8bit", b"7bit"):
writer = f

elif transfer_encoding == b"base64":
Expand Down Expand Up @@ -1862,8 +1855,8 @@ def create_form_parser(
on_field: OnFieldCallback,
on_file: OnFileCallback,
trust_x_headers: bool = False,
config={},
):
config: dict[Any, Any] = {},
) -> FormParser:
"""This function is a helper function to aid in creating a FormParser
instances. Given a dictionary-like headers object, it will determine
the correct information needed, instantiate a FormParser with the
Expand Down Expand Up @@ -1912,8 +1905,7 @@ def parse_form(
on_field: OnFieldCallback,
on_file: OnFileCallback,
chunk_size: int = 1048576,
**kwargs,
):
) -> None:
"""This function is useful if you just want to parse a request body,
without too much work. Pass it a dictionary-like object of the request's
headers, and a file-like object for the input stream, along with two
Expand Down
6 changes: 5 additions & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import unittest
from io import BytesIO
from typing import TYPE_CHECKING
from unittest.mock import Mock

import yaml
Expand All @@ -28,6 +29,9 @@

from .compat import parametrize, parametrize_class, slow_test

if TYPE_CHECKING:
from multipart.multipart import FileConfig

# Get the current directory for our later test cases.
curr_dir = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -95,7 +99,7 @@ def test_set_none(self):

class TestFile(unittest.TestCase):
def setUp(self):
self.c = {}
self.c: FileConfig = {}
self.d = force_bytes(tempfile.mkdtemp())
self.f = File(b"foo.txt", config=self.c)

Expand Down

0 comments on commit f4479c6

Please sign in to comment.