diff --git a/CHANGES/8768.bugfix.rst b/CHANGES/8768.bugfix.rst new file mode 100644 index 00000000000..18512163572 --- /dev/null +++ b/CHANGES/8768.bugfix.rst @@ -0,0 +1 @@ +Used more precise type for ``ClientResponseError.headers``, fixing some type errors when using them -- by :user:`Dreamorcerer`. diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 6a8bb5f8035..eb5e1b09692 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -3,8 +3,10 @@ import asyncio from typing import TYPE_CHECKING, Optional, Tuple, Union +from multidict import MultiMapping + from .http_parser import RawResponseMessage -from .typedefs import LooseHeaders, StrOrURL +from .typedefs import StrOrURL try: import ssl @@ -69,7 +71,7 @@ def __init__( *, status: Optional[int] = None, message: str = "", - headers: Optional[LooseHeaders] = None, + headers: Optional[MultiMapping[str]] = None, ) -> None: self.request_info = request_info if status is not None: diff --git a/tests/test_client_exceptions.py b/tests/test_client_exceptions.py index 6ae7f0dca71..9ec9f48b27b 100644 --- a/tests/test_client_exceptions.py +++ b/tests/test_client_exceptions.py @@ -43,7 +43,7 @@ def test_pickle(self) -> None: history=(), status=400, message="Something wrong", - headers={}, + headers=CIMultiDict(foo="bar"), ) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -53,7 +53,8 @@ def test_pickle(self) -> None: assert err2.history == () assert err2.status == 400 assert err2.message == "Something wrong" - assert err2.headers == {} + # Use headers.get() to verify static type is correct. + assert err2.headers.get("foo") == "bar" assert err2.foo == "bar" def test_repr(self) -> None: @@ -65,11 +66,11 @@ def test_repr(self) -> None: history=(), status=400, message="Something wrong", - headers={}, + headers=CIMultiDict(), ) assert repr(err) == ( "ClientResponseError(%r, (), status=400, " - "message='Something wrong', headers={})" % (self.request_info,) + "message='Something wrong', headers=)" % (self.request_info,) ) def test_str(self) -> None: @@ -78,7 +79,7 @@ def test_str(self) -> None: history=(), status=400, message="Something wrong", - headers={}, + headers=CIMultiDict(), ) assert str(err) == ("400, message='Something wrong', url='http://example.com'")