Skip to content

Commit

Permalink
Backport #8620 as improvements to various type annotations (#8634)
Browse files Browse the repository at this point in the history
(cherry picked from commit c7293e1)
  • Loading branch information
Dreamsorcerer authored and patchback[bot] committed Aug 7, 2024
1 parent 4815765 commit 0b88abf
Show file tree
Hide file tree
Showing 14 changed files with 188 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ omit = site-packages
[report]
exclude_also =
if TYPE_CHECKING
assert False
: \.\.\.(\s*#.*)?$
^ +\.\.\.$
2 changes: 1 addition & 1 deletion .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 3.11
- name: Cache PyPI
uses: actions/cache@v4.0.2
with:
Expand Down
1 change: 1 addition & 0 deletions CHANGES/8634.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Minor improvements to various type annotations -- by :user:`Dreamsorcerer`.
29 changes: 20 additions & 9 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import traceback
import warnings
from contextlib import suppress
from types import SimpleNamespace, TracebackType
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -155,7 +155,7 @@


class _RequestOptions(TypedDict, total=False):
params: Union[Mapping[str, str], None]
params: Union[Mapping[str, Union[str, int]], str, None]
data: Any
json: Any
cookies: Union[LooseCookies, None]
Expand All @@ -175,7 +175,7 @@ class _RequestOptions(TypedDict, total=False):
ssl: Union[SSLContext, bool, Fingerprint]
server_hostname: Union[str, None]
proxy_headers: Union[LooseHeaders, None]
trace_request_ctx: Union[SimpleNamespace, None]
trace_request_ctx: Union[Mapping[str, str], None]
read_bufsize: Union[int, None]
auto_decompress: Union[bool, None]
max_line_size: Union[int, None]
Expand Down Expand Up @@ -422,11 +422,22 @@ def __del__(self, _warnings: Any = warnings) -> None:
context["source_traceback"] = self._source_traceback
self._loop.call_exception_handler(context)

def request(
self, method: str, url: StrOrURL, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP request."""
return _RequestContextManager(self._request(method, url, **kwargs))
if sys.version_info >= (3, 11) and TYPE_CHECKING:

def request(
self,
method: str,
url: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> "_RequestContextManager": ...

else:

def request(
self, method: str, url: StrOrURL, **kwargs: Any
) -> "_RequestContextManager":
"""Perform HTTP request."""
return _RequestContextManager(self._request(method, url, **kwargs))

def _build_url(self, str_or_url: StrOrURL) -> URL:
url = URL(str_or_url)
Expand Down Expand Up @@ -466,7 +477,7 @@ async def _request(
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
trace_request_ctx: Optional[Mapping[str, str]] = None,
read_bufsize: Optional[int] = None,
auto_decompress: Optional[bool] = None,
max_line_size: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __str__(self) -> str:
return "{}, message={!r}, url={!r}".format(
self.status,
self.message,
self.request_info.real_url,
str(self.request_info.real_url),
)

def __repr__(self) -> str:
Expand Down
13 changes: 9 additions & 4 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ class ClientRequest:
hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
}

body = b""
# Type of body depends on PAYLOAD_REGISTRY, which is dynamic.
body: Any = b""
auth = None
response = None

Expand Down Expand Up @@ -441,7 +442,7 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None:

if headers:
if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
headers = headers.items() # type: ignore[assignment]
headers = headers.items()

for key, value in headers: # type: ignore[misc]
# A special case for Host header
Expand Down Expand Up @@ -597,6 +598,10 @@ def update_proxy(
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
self.proxy = proxy
self.proxy_auth = proxy_auth
if proxy_headers is not None and not isinstance(
proxy_headers, (MultiDict, MultiDictProxy)
):
proxy_headers = CIMultiDict(proxy_headers)
self.proxy_headers = proxy_headers

def keep_alive(self) -> bool:
Expand Down Expand Up @@ -632,10 +637,10 @@ async def write_bytes(
await self.body.write(writer)
else:
if isinstance(self.body, (bytes, bytearray)):
self.body = (self.body,) # type: ignore[assignment]
self.body = (self.body,)

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]
await writer.write(chunk)
except OSError as underlying_exc:
reraised_exc = underlying_exc

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand Down Expand Up @@ -833,7 +834,7 @@ def clear_dns_cache(
self._cached_hosts.clear()

async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
) -> List[ResolveResult]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
Expand Down Expand Up @@ -902,7 +903,7 @@ async def _resolve_host_with_throttle(
key: Tuple[str, int],
host: str,
port: int,
traces: Optional[List["Trace"]],
traces: Optional[Sequence["Trace"]],
) -> List[ResolveResult]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
Expand Down
34 changes: 30 additions & 4 deletions aiohttp/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import contextlib
import inspect
import warnings
from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterator,
Optional,
Protocol,
Type,
Union,
)

import pytest

Expand All @@ -24,9 +34,23 @@
except ImportError: # pragma: no cover
uvloop = None # type: ignore[assignment]

AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]]
AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]]
AiohttpServer = Callable[[Application], Awaitable[TestServer]]


class AiohttpClient(Protocol):
def __call__(
self,
__param: Union[Application, BaseTestServer],
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> Awaitable[TestClient]: ...


class AiohttpServer(Protocol):
def __call__(
self, app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> Awaitable[TestServer]: ...


def pytest_addoption(parser): # type: ignore[no-untyped-def]
Expand Down Expand Up @@ -262,7 +286,9 @@ def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
"""
servers = []

async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def]
async def go(
app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> TestServer:
server = TestServer(app, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
Expand Down
Loading

0 comments on commit 0b88abf

Please sign in to comment.