Skip to content

Commit

Permalink
Move utility functions from _utils.py to _client.py (#3389)
Browse files Browse the repository at this point in the history
  • Loading branch information
RafaelWO authored Nov 15, 2024
1 parent b47d94c commit 7b19cd5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 100 deletions.
43 changes: 35 additions & 8 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@
TimeoutTypes,
)
from ._urls import URL, QueryParams
from ._utils import (
URLPattern,
get_environment_proxies,
is_https_redirect,
same_origin,
)
from ._utils import URLPattern, get_environment_proxies

if typing.TYPE_CHECKING:
import ssl # pragma: no cover
Expand All @@ -63,6 +58,38 @@
U = typing.TypeVar("U", bound="AsyncClient")


def _is_https_redirect(url: URL, location: URL) -> bool:
"""
Return 'True' if 'location' is a HTTPS upgrade of 'url'
"""
if url.host != location.host:
return False

return (
url.scheme == "http"
and _port_or_default(url) == 80
and location.scheme == "https"
and _port_or_default(location) == 443
)


def _port_or_default(url: URL) -> int | None:
if url.port is not None:
return url.port
return {"http": 80, "https": 443}.get(url.scheme)


def _same_origin(url: URL, other: URL) -> bool:
"""
Return 'True' if the given URLs share the same origin.
"""
return (
url.scheme == other.scheme
and url.host == other.host
and _port_or_default(url) == _port_or_default(other)
)


class UseClientDefault:
"""
For some parameters such as `auth=...` and `timeout=...` we need to be able
Expand Down Expand Up @@ -521,8 +548,8 @@ def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
"""
headers = Headers(request.headers)

if not same_origin(url, request.url):
if not is_https_redirect(request.url, url):
if not _same_origin(url, request.url):
if not _is_https_redirect(request.url, url):
# Strip Authorization headers when responses are redirected
# away from the origin. (Except for direct HTTP to HTTPS redirects.)
headers.pop("Authorization", None)
Expand Down
32 changes: 0 additions & 32 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,6 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return str(value)


def port_or_default(url: URL) -> int | None:
if url.port is not None:
return url.port
return {"http": 80, "https": 443}.get(url.scheme)


def same_origin(url: URL, other: URL) -> bool:
"""
Return 'True' if the given URLs share the same origin.
"""
return (
url.scheme == other.scheme
and url.host == other.host
and port_or_default(url) == port_or_default(other)
)


def is_https_redirect(url: URL, location: URL) -> bool:
"""
Return 'True' if 'location' is a HTTPS upgrade of 'url'
"""
if url.host != location.host:
return False

return (
url.scheme == "http"
and port_or_default(url) == 80
and location.scheme == "https"
and port_or_default(location) == 443
)


def get_environment_proxies() -> dict[str, str | None]:
"""Gets proxy information from the environment"""

Expand Down
56 changes: 56 additions & 0 deletions tests/client/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,59 @@ def test_host_with_non_default_port_in_url():
def test_request_auto_headers():
request = httpx.Request("GET", "https://www.example.org/")
assert "host" in request.headers


def test_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")

client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")

assert headers["Host"] == request.url.netloc.decode("ascii")


def test_not_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")

client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")

assert headers["Host"] == origin.netloc.decode("ascii")


def test_is_https_redirect():
url = httpx.URL("https://example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" in headers


def test_is_not_https_redirect():
url = httpx.URL("https://www.example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" not in headers


def test_is_not_https_redirect_if_not_default_ports():
url = httpx.URL("https://example.com:1337")
request = httpx.Request(
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" not in headers
61 changes: 1 addition & 60 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
import pytest

import httpx
from httpx._utils import (
URLPattern,
get_environment_proxies,
)
from httpx._utils import URLPattern, get_environment_proxies


@pytest.mark.parametrize(
Expand Down Expand Up @@ -115,62 +112,6 @@ def test_get_environment_proxies(environment, proxies):
assert get_environment_proxies() == proxies


def test_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")

client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")

assert headers["Host"] == request.url.netloc.decode("ascii")


def test_not_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")

client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")

assert headers["Host"] == origin.netloc.decode("ascii")


def test_is_https_redirect():
url = httpx.URL("https://example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" in headers


def test_is_not_https_redirect():
url = httpx.URL("https://www.example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" not in headers


def test_is_not_https_redirect_if_not_default_ports():
url = httpx.URL("https://example.com:1337")
request = httpx.Request(
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
)

client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")

assert "Authorization" not in headers


@pytest.mark.parametrize(
["pattern", "url", "expected"],
[
Expand Down

0 comments on commit 7b19cd5

Please sign in to comment.