Skip to content

Commit 56db30e

Browse files
srittauAvasam
andauthored
Make email.message.Message generic over the header type (#11732)
Co-authored-by: Avasam <samuel.06@hotmail.com>
1 parent c13f6e1 commit 56db30e

File tree

3 files changed

+58
-63
lines changed

3 files changed

+58
-63
lines changed

stdlib/email/message.pyi

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,33 @@ from email import _ParamsType, _ParamType
33
from email.charset import Charset
44
from email.contentmanager import ContentManager
55
from email.errors import MessageDefect
6-
from email.header import Header
76
from email.policy import Policy
8-
from typing import Any, Literal, Protocol, TypeVar, overload
7+
from typing import Any, Generic, Literal, Protocol, TypeVar, overload
98
from typing_extensions import Self, TypeAlias
109

1110
__all__ = ["Message", "EmailMessage"]
1211

1312
_T = TypeVar("_T")
13+
# Type returned by Policy.header_fetch_parse, often str or Header.
14+
_HeaderT = TypeVar("_HeaderT", default=str)
15+
_HeaderParamT = TypeVar("_HeaderParamT", default=str)
16+
# Represents headers constructed by HeaderRegistry. Those are sub-classes
17+
# of BaseHeader and another header type.
18+
_HeaderRegistryT = TypeVar("_HeaderRegistryT", default=Any)
19+
_HeaderRegistryParamT = TypeVar("_HeaderRegistryParamT", default=Any)
20+
1421
_PayloadType: TypeAlias = Message | str
1522
_EncodedPayloadType: TypeAlias = Message | bytes
1623
_MultipartPayloadType: TypeAlias = list[_PayloadType]
1724
_CharsetType: TypeAlias = Charset | str | None
18-
# Type returned by Policy.header_fetch_parse, often str or Header.
19-
_HeaderType: TypeAlias = Any
20-
# Type accepted by Policy.header_store_parse.
21-
_HeaderTypeParam: TypeAlias = str | Header | Any
2225

2326
class _SupportsEncodeToPayload(Protocol):
2427
def encode(self, encoding: str, /) -> _PayloadType | _MultipartPayloadType | _SupportsDecodeToPayload: ...
2528

2629
class _SupportsDecodeToPayload(Protocol):
2730
def decode(self, encoding: str, errors: str, /) -> _PayloadType | _MultipartPayloadType: ...
2831

29-
# TODO: This class should be generic over the header policy and/or the header
30-
# value types allowed by the policy. This depends on PEP 696 support
31-
# (https://github.com/python/typeshed/issues/11422).
32-
class Message:
32+
class Message(Generic[_HeaderT, _HeaderParamT]):
3333
policy: Policy # undocumented
3434
preamble: str | None
3535
epilogue: str | None
@@ -70,24 +70,23 @@ class Message:
7070
# Same as `get` with `failobj=None`, but with the expectation that it won't return None in most scenarios
7171
# This is important for protocols using __getitem__, like SupportsKeysAndGetItem
7272
# Morally, the return type should be `AnyOf[_HeaderType, None]`,
73-
# which we could spell as `_HeaderType | Any`,
74-
# *but* `_HeaderType` itself is currently an alias to `Any`...
75-
def __getitem__(self, name: str) -> _HeaderType: ...
76-
def __setitem__(self, name: str, val: _HeaderTypeParam) -> None: ...
73+
# so using "the Any trick" instead.
74+
def __getitem__(self, name: str) -> _HeaderT | Any: ...
75+
def __setitem__(self, name: str, val: _HeaderParamT) -> None: ...
7776
def __delitem__(self, name: str) -> None: ...
7877
def keys(self) -> list[str]: ...
79-
def values(self) -> list[_HeaderType]: ...
80-
def items(self) -> list[tuple[str, _HeaderType]]: ...
78+
def values(self) -> list[_HeaderT]: ...
79+
def items(self) -> list[tuple[str, _HeaderT]]: ...
8180
@overload
82-
def get(self, name: str, failobj: None = None) -> _HeaderType | None: ...
81+
def get(self, name: str, failobj: None = None) -> _HeaderT | None: ...
8382
@overload
84-
def get(self, name: str, failobj: _T) -> _HeaderType | _T: ...
83+
def get(self, name: str, failobj: _T) -> _HeaderT | _T: ...
8584
@overload
86-
def get_all(self, name: str, failobj: None = None) -> list[_HeaderType] | None: ...
85+
def get_all(self, name: str, failobj: None = None) -> list[_HeaderT] | None: ...
8786
@overload
88-
def get_all(self, name: str, failobj: _T) -> list[_HeaderType] | _T: ...
87+
def get_all(self, name: str, failobj: _T) -> list[_HeaderT] | _T: ...
8988
def add_header(self, _name: str, _value: str, **_params: _ParamsType) -> None: ...
90-
def replace_header(self, _name: str, _value: _HeaderTypeParam) -> None: ...
89+
def replace_header(self, _name: str, _value: _HeaderParamT) -> None: ...
9190
def get_content_type(self) -> str: ...
9291
def get_content_maintype(self) -> str: ...
9392
def get_content_subtype(self) -> str: ...
@@ -141,14 +140,14 @@ class Message:
141140
) -> None: ...
142141
def __init__(self, policy: Policy = ...) -> None: ...
143142
# The following two methods are undocumented, but a source code comment states that they are public API
144-
def set_raw(self, name: str, value: _HeaderTypeParam) -> None: ...
145-
def raw_items(self) -> Iterator[tuple[str, _HeaderType]]: ...
143+
def set_raw(self, name: str, value: _HeaderParamT) -> None: ...
144+
def raw_items(self) -> Iterator[tuple[str, _HeaderT]]: ...
146145

147-
class MIMEPart(Message):
146+
class MIMEPart(Message[_HeaderRegistryT, _HeaderRegistryParamT]):
148147
def __init__(self, policy: Policy | None = None) -> None: ...
149-
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> Message | None: ...
150-
def iter_attachments(self) -> Iterator[Message]: ...
151-
def iter_parts(self) -> Iterator[Message]: ...
148+
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT] | None: ...
149+
def iter_attachments(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
150+
def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
152151
def get_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> Any: ...
153152
def set_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> None: ...
154153
def make_related(self, boundary: str | None = None) -> None: ...

stdlib/email/parser.pyi

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,34 @@ from collections.abc import Callable
33
from email.feedparser import BytesFeedParser as BytesFeedParser, FeedParser as FeedParser
44
from email.message import Message
55
from email.policy import Policy
6-
from typing import IO
6+
from io import _WrappedBuffer
7+
from typing import Generic, TypeVar, overload
78

89
__all__ = ["Parser", "HeaderParser", "BytesParser", "BytesHeaderParser", "FeedParser", "BytesFeedParser"]
910

10-
class Parser:
11-
def __init__(self, _class: Callable[[], Message] | None = None, *, policy: Policy = ...) -> None: ...
12-
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> Message: ...
13-
def parsestr(self, text: str, headersonly: bool = False) -> Message: ...
11+
_MessageT = TypeVar("_MessageT", bound=Message, default=Message)
1412

15-
class HeaderParser(Parser):
16-
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> Message: ...
17-
def parsestr(self, text: str, headersonly: bool = True) -> Message: ...
13+
class Parser(Generic[_MessageT]):
14+
@overload
15+
def __init__(self: Parser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
16+
@overload
17+
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
18+
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> _MessageT: ...
19+
def parsestr(self, text: str, headersonly: bool = False) -> _MessageT: ...
1820

19-
class BytesParser:
20-
def __init__(self, _class: Callable[[], Message] = ..., *, policy: Policy = ...) -> None: ...
21-
def parse(self, fp: IO[bytes], headersonly: bool = False) -> Message: ...
22-
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> Message: ...
21+
class HeaderParser(Parser[_MessageT]):
22+
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> _MessageT: ...
23+
def parsestr(self, text: str, headersonly: bool = True) -> _MessageT: ...
2324

24-
class BytesHeaderParser(BytesParser):
25-
def parse(self, fp: IO[bytes], headersonly: bool = True) -> Message: ...
26-
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> Message: ...
25+
class BytesParser(Generic[_MessageT]):
26+
parser: Parser[_MessageT]
27+
@overload
28+
def __init__(self: BytesParser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
29+
@overload
30+
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
31+
def parse(self, fp: _WrappedBuffer, headersonly: bool = False) -> _MessageT: ...
32+
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> _MessageT: ...
33+
34+
class BytesHeaderParser(BytesParser[_MessageT]):
35+
def parse(self, fp: _WrappedBuffer, headersonly: bool = True) -> _MessageT: ...
36+
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> _MessageT: ...

stdlib/http/client.pyi

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import io
33
import ssl
44
import sys
55
import types
6-
from _typeshed import ReadableBuffer, SupportsRead, WriteableBuffer
6+
from _typeshed import ReadableBuffer, SupportsRead, SupportsReadline, WriteableBuffer
77
from collections.abc import Callable, Iterable, Iterator, Mapping
88
from socket import socket
99
from typing import Any, BinaryIO, TypeVar, overload
@@ -33,6 +33,7 @@ __all__ = [
3333

3434
_DataType: TypeAlias = SupportsRead[bytes] | Iterable[ReadableBuffer] | ReadableBuffer
3535
_T = TypeVar("_T")
36+
_MessageT = TypeVar("_MessageT", bound=email.message.Message)
3637

3738
HTTP_PORT: int
3839
HTTPS_PORT: int
@@ -97,28 +98,13 @@ NETWORK_AUTHENTICATION_REQUIRED: int
9798

9899
responses: dict[int, str]
99100

100-
class HTTPMessage(email.message.Message):
101+
class HTTPMessage(email.message.Message[str, str]):
101102
def getallmatchingheaders(self, name: str) -> list[str]: ... # undocumented
102-
# override below all of Message's methods that use `_HeaderType` / `_HeaderTypeParam` with `str`
103-
# `HTTPMessage` breaks the Liskov substitution principle by only intending for `str` headers
104-
# This is easier than making `Message` generic
105-
def __getitem__(self, name: str) -> str | None: ...
106-
def __setitem__(self, name: str, val: str) -> None: ... # type: ignore[override]
107-
def values(self) -> list[str]: ...
108-
def items(self) -> list[tuple[str, str]]: ...
109-
@overload
110-
def get(self, name: str, failobj: None = None) -> str | None: ...
111-
@overload
112-
def get(self, name: str, failobj: _T) -> str | _T: ...
113-
@overload
114-
def get_all(self, name: str, failobj: None = None) -> list[str] | None: ...
115-
@overload
116-
def get_all(self, name: str, failobj: _T) -> list[str] | _T: ...
117-
def replace_header(self, _name: str, _value: str) -> None: ... # type: ignore[override]
118-
def set_raw(self, name: str, value: str) -> None: ... # type: ignore[override]
119-
def raw_items(self) -> Iterator[tuple[str, str]]: ...
120103

121-
def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...
104+
@overload
105+
def parse_headers(fp: SupportsReadline[bytes], _class: Callable[[], _MessageT]) -> _MessageT: ...
106+
@overload
107+
def parse_headers(fp: SupportsReadline[bytes]) -> HTTPMessage: ...
122108

123109
class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes
124110
msg: HTTPMessage

0 commit comments

Comments
 (0)