Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve header casing #1338

Merged
merged 2 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 60 additions & 45 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,27 +526,28 @@ class Headers(typing.MutableMapping[str, str]):

def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
if headers is None:
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
elif isinstance(headers, Headers):
self._list = list(headers.raw)
self._list = list(headers._list)
elif isinstance(headers, dict):
self._list = [
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
(
normalize_header_key(k, lower=False, encoding=encoding),
normalize_header_key(k, lower=True, encoding=encoding),
normalize_header_value(v, encoding),
)
for k, v in headers.items()
]
else:
self._list = [
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
(
normalize_header_key(k, lower=False, encoding=encoding),
normalize_header_key(k, lower=True, encoding=encoding),
normalize_header_value(v, encoding),
)
for k, v in headers
]

self._dict = {} # type: typing.Dict[bytes, bytes]
for key, value in self._list:
if key in self._dict:
self._dict[key] = self._dict[key] + b", " + value
else:
self._dict[key] = value

self._encoding = encoding

@property
Expand Down Expand Up @@ -583,25 +584,36 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
"""
Returns a list of the raw header items, as byte pairs.
"""
return list(self._list)
return [(raw_key, value) for raw_key, _, value in self._list]

def keys(self) -> typing.KeysView[str]:
return {key.decode(self.encoding): None for key in self._dict.keys()}.keys()
return {key.decode(self.encoding): None for _, key, value in self._list}.keys()

def values(self) -> typing.ValuesView[str]:
return {
key: value.decode(self.encoding) for key, value in self._dict.items()
}.values()
values_dict: typing.Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
values_dict[str_key] = str_value
return values_dict.values()

def items(self) -> typing.ItemsView[str, str]:
"""
Return `(key, value)` items of headers. Concatenate headers
into a single comma seperated value when a key occurs multiple times.
"""
return {
key.decode(self.encoding): value.decode(self.encoding)
for key, value in self._dict.items()
}.items()
values_dict: typing.Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
values_dict[str_key] = str_value
return values_dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
Expand All @@ -611,7 +623,7 @@ def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
return [
(key.decode(self.encoding), value.decode(self.encoding))
for key, value in self._list
for _, key, value in self._list
]

def get(self, key: str, default: typing.Any = None) -> typing.Any:
Expand All @@ -634,8 +646,8 @@ def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:

values = [
item_value.decode(self.encoding)
for item_key, item_value in self._list
if item_key == get_header_key
for _, item_key, item_value in self._list
if item_key.lower() == get_header_key
]

if not split_commas:
Expand All @@ -648,11 +660,11 @@ def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:

def update(self, headers: HeaderTypes = None) -> None: # type: ignore
headers = Headers(headers)
for header in headers:
self[header] = headers[header]
for key, value in headers.raw:
self[key.decode(headers.encoding)] = value.decode(headers.encoding)

def copy(self) -> "Headers":
return Headers(dict(self.items()), encoding=self.encoding)
return Headers(self, encoding=self.encoding)

def __getitem__(self, key: str) -> str:
"""
Expand All @@ -664,7 +676,7 @@ def __getitem__(self, key: str) -> str:
normalized_key = key.lower().encode(self.encoding)

items = []
for header_key, header_value in self._list:
for _, header_key, header_value in self._list:
if header_key == normalized_key:
items.append(header_value.decode(self.encoding))

Expand All @@ -678,44 +690,44 @@ def __setitem__(self, key: str, value: str) -> None:
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.lower().encode(self._encoding or "utf-8")
set_key = key.encode(self._encoding or "utf-8")
set_value = value.encode(self._encoding or "utf-8")

self._dict[set_key] = set_value
lookup_key = set_key.lower()

found_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == set_key:
for idx, (_, item_key, _) in enumerate(self._list):
if item_key == lookup_key:
found_indexes.append(idx)

for idx in reversed(found_indexes[1:]):
del self._list[idx]

if found_indexes:
idx = found_indexes[0]
self._list[idx] = (set_key, set_value)
self._list[idx] = (set_key, lookup_key, set_value)
else:
self._list.append((set_key, set_value))
self._list.append((set_key, lookup_key, set_value))

def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
del_key = key.lower().encode(self.encoding)

del self._dict[del_key]

pop_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == del_key:
for idx, (_, item_key, _) in enumerate(self._list):
if item_key.lower() == del_key:
pop_indexes.append(idx)

if not pop_indexes:
raise KeyError(key)

for idx in reversed(pop_indexes):
del self._list[idx]

def __contains__(self, key: typing.Any) -> bool:
header_key = key.lower().encode(self.encoding)
return header_key in self._dict
return header_key in [key for _, key, _ in self._list]

def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
Expand All @@ -728,7 +740,10 @@ def __eq__(self, other: typing.Any) -> bool:
other_headers = Headers(other)
except ValueError:
return False
return sorted(self._list) == sorted(other_headers._list)

self_list = [(key, value) for _, key, value in self._list]
other_list = [(key, value) for _, key, value in other_headers._list]
return sorted(self_list) == sorted(other_list)

def __repr__(self) -> str:
class_name = self.__class__.__name__
Expand Down Expand Up @@ -794,15 +809,15 @@ def __init__(
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
for key, value in default_headers.items():
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
if key.lower() == "transfer-encoding" and "Content-Length" in self.headers:
continue
self.headers.setdefault(key, value)

auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []

has_host = "host" in self.headers
has_host = "Host" in self.headers
has_content_length = (
"content-length" in self.headers or "transfer-encoding" in self.headers
"Content-Length" in self.headers or "Transfer-Encoding" in self.headers
)

if not has_host and self.url.host:
Expand All @@ -811,9 +826,9 @@ def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
host_header = self.url.host.encode("ascii")
else:
host_header = self.url.netloc.encode("ascii")
auto_headers.append((b"host", host_header))
auto_headers.append((b"Host", host_header))
if not has_content_length and self.method in ("POST", "PUT", "PATCH"):
auto_headers.append((b"content-length", b"0"))
auto_headers.append((b"Content-Length", b"0"))

self.headers = Headers(auto_headers + self.headers.raw)

Expand Down
11 changes: 8 additions & 3 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@


def normalize_header_key(
value: typing.Union[str, bytes], encoding: str = None
value: typing.Union[str, bytes],
lower: bool,
encoding: str = None,
) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header key.
"""
if isinstance(value, bytes):
return value.lower()
return value.encode(encoding or "ascii").lower()
bytes_value = value
else:
bytes_value = value.encode(encoding or "ascii")

return bytes_value.lower() if lower else bytes_value


def normalize_header_value(
Expand Down
30 changes: 30 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import httpx
from tests.utils import MockTransport


def test_get(server):
Expand Down Expand Up @@ -271,3 +272,32 @@ def test_that_client_is_closed_after_with_block():
pass

assert client.is_closed


def echo_raw_headers(request: httpx.Request) -> httpx.Response:
data = [
(name.decode("ascii"), value.decode("ascii"))
for name, value in request.headers.raw
]
return httpx.Response(200, json=data)


def test_raw_client_header():
"""
Set a header in the Client.
"""
url = "http://example.org/echo_headers"
headers = {"Example-Header": "example-value"}

client = httpx.Client(transport=MockTransport(echo_raw_headers), headers=headers)
response = client.get(url)

assert response.status_code == 200
assert response.json() == [
["Host", "example.org"],
["Accept", "*/*"],
["Accept-Encoding", "gzip, deflate, br"],
["Connection", "keep-alive"],
["User-Agent", f"python-httpx/{httpx.__version__}"],
["Example-Header", "example-value"],
]