Skip to content

Commit

Permalink
Fixed RuntimeWarning caused by unclosed AsyncClient (#871)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdeler committed Aug 21, 2020
1 parent 77ba7a1 commit 458e6c9
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 135 deletions.
96 changes: 56 additions & 40 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ async def test_basic_auth() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, auth=auth)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
Expand All @@ -200,8 +200,8 @@ async def test_basic_auth() -> None:
async def test_basic_auth_in_url() -> None:
url = "https://tomchristie:password123@example.org/"

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
Expand All @@ -212,8 +212,8 @@ async def test_basic_auth_on_session() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")

client = AsyncClient(transport=AsyncMockTransport(), auth=auth)
response = await client.get(url)
async with AsyncClient(transport=AsyncMockTransport(), auth=auth) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
Expand All @@ -227,8 +227,8 @@ def auth(request: Request) -> Request:
request.headers["Authorization"] = "Token 123"
return request

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, auth=auth)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": "Token 123"}
Expand All @@ -239,8 +239,8 @@ async def test_netrc_auth() -> None:
os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
url = "http://netrcexample.org"

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {
Expand All @@ -253,8 +253,8 @@ async def test_auth_header_has_priority_over_netrc() -> None:
os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
url = "http://netrcexample.org"

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, headers={"Authorization": "Override"})
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, headers={"Authorization": "Override"})

assert response.status_code == 200
assert response.json() == {"auth": "Override"}
Expand All @@ -265,14 +265,14 @@ async def test_trust_env_auth() -> None:
os.environ["NETRC"] = str(FIXTURES_DIR / ".netrc")
url = "http://netrcexample.org"

client = AsyncClient(transport=AsyncMockTransport(), trust_env=False)
response = await client.get(url)
async with AsyncClient(transport=AsyncMockTransport(), trust_env=False) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"auth": None}

client = AsyncClient(transport=AsyncMockTransport(), trust_env=True)
response = await client.get(url)
async with AsyncClient(transport=AsyncMockTransport(), trust_env=True) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {
Expand All @@ -285,8 +285,8 @@ async def test_auth_disable_per_request() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")

client = AsyncClient(transport=AsyncMockTransport(), auth=auth)
response = await client.get(url, auth=None)
async with AsyncClient(transport=AsyncMockTransport(), auth=auth) as client:
response = await client.get(url, auth=None)

assert response.status_code == 200
assert response.json() == {"auth": None}
Expand All @@ -304,8 +304,8 @@ async def test_auth_hidden_header() -> None:
url = "https://example.org/"
auth = ("example-username", "example-password")

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, auth=auth)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

assert "'authorization': '[secure]'" in str(response.request.headers)

Expand All @@ -323,6 +323,8 @@ async def test_auth_property() -> None:
assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}

await client.aclose()


@pytest.mark.asyncio
async def test_auth_invalid_type() -> None:
Expand All @@ -340,14 +342,16 @@ async def test_auth_invalid_type() -> None:
with pytest.raises(TypeError):
client.auth = "not a tuple, not a callable" # type: ignore

await client.aclose()


@pytest.mark.asyncio
async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, auth=auth)
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": None}
Expand All @@ -360,10 +364,10 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
auth = DigestAuth(username="tomchristie", password="password123")
auth_header = b'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'

client = AsyncClient(
async with AsyncClient(
transport=AsyncMockTransport(auth_header=auth_header, status_code=200)
)
response = await client.get(url, auth=auth)
) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": None}
Expand All @@ -375,8 +379,10 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(transport=AsyncMockTransport(auth_header=b"", status_code=401))
response = await client.get(url, auth=auth)
async with AsyncClient(
transport=AsyncMockTransport(auth_header=b"", status_code=401)
) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 401
assert response.json() == {"auth": None}
Expand All @@ -403,8 +409,10 @@ async def test_digest_auth(
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(transport=MockDigestAuthTransport(algorithm=algorithm))
response = await client.get(url, auth=auth)
async with AsyncClient(
transport=MockDigestAuthTransport(algorithm=algorithm)
) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert len(response.history) == 1
Expand Down Expand Up @@ -433,8 +441,8 @@ async def test_digest_auth_no_specified_qop() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(transport=MockDigestAuthTransport(qop=""))
response = await client.get(url, auth=auth)
async with AsyncClient(transport=MockDigestAuthTransport(qop="")) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert len(response.history) == 1
Expand Down Expand Up @@ -464,8 +472,8 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(transport=MockDigestAuthTransport(qop=qop))
response = await client.get(url, auth=auth)
async with AsyncClient(transport=MockDigestAuthTransport(qop=qop)) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert len(response.history) == 1
Expand All @@ -480,6 +488,8 @@ async def test_digest_auth_qop_auth_int_not_implemented() -> None:
with pytest.raises(NotImplementedError):
await client.get(url, auth=auth)

await client.aclose()


@pytest.mark.asyncio
async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
Expand All @@ -490,16 +500,18 @@ async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
with pytest.raises(ProtocolError):
await client.get(url, auth=auth)

await client.aclose()


@pytest.mark.asyncio
async def test_digest_auth_incorrect_credentials() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")

client = AsyncClient(
async with AsyncClient(
transport=MockDigestAuthTransport(send_response_after_attempt=2)
)
response = await client.get(url, auth=auth)
) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 401
assert len(response.history) == 1
Expand Down Expand Up @@ -528,6 +540,8 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
with pytest.raises(ProtocolError):
await client.get(url, auth=auth)

await client.aclose()


@pytest.mark.parametrize(
"auth_header",
Expand Down Expand Up @@ -560,9 +574,9 @@ async def test_async_auth_history() -> None:
"""
url = "https://example.org/"
auth = RepeatAuth(repeat=2)
client = AsyncClient(transport=AsyncMockTransport(auth_header=b"abc"))
async with AsyncClient(transport=AsyncMockTransport(auth_header=b"abc")) as client:
response = await client.get(url, auth=auth)

response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "Repeat abc.abc"}

Expand Down Expand Up @@ -613,6 +627,8 @@ async def streaming_body():
with pytest.raises(RequestBodyUnavailable):
await client.post(url, data=streaming_body(), auth=auth)

await client.aclose()


@pytest.mark.asyncio
async def test_async_auth_reads_response_body() -> None:
Expand All @@ -622,9 +638,9 @@ async def test_async_auth_reads_response_body() -> None:
"""
url = "https://example.org/"
auth = ResponseBodyAuth("xyz")
client = AsyncClient(transport=AsyncMockTransport())
async with AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}

Expand Down
22 changes: 13 additions & 9 deletions tests/client/test_cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ async def test_set_cookie() -> None:
url = "http://example.org/echo_cookies"
cookies = {"example-name": "example-value"}

client = AsyncClient(transport=MockTransport())
response = await client.get(url, cookies=cookies)
async with AsyncClient(transport=MockTransport()) as client:
response = await client.get(url, cookies=cookies)

assert response.status_code == 200
assert response.json() == {"cookies": "example-name=example-value"}
Expand Down Expand Up @@ -85,8 +85,8 @@ async def test_set_cookie_with_cookiejar() -> None:
)
cookies.set_cookie(cookie)

client = AsyncClient(transport=MockTransport())
response = await client.get(url, cookies=cookies)
async with AsyncClient(transport=MockTransport()) as client:
response = await client.get(url, cookies=cookies)

assert response.status_code == 200
assert response.json() == {"cookies": "example-name=example-value"}
Expand Down Expand Up @@ -121,9 +121,9 @@ async def test_setting_client_cookies_to_cookiejar() -> None:
)
cookies.set_cookie(cookie)

client = AsyncClient(transport=MockTransport())
client.cookies = cookies # type: ignore
response = await client.get(url)
async with AsyncClient(transport=MockTransport()) as client:
client.cookies = cookies # type: ignore
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"cookies": "example-name=example-value"}
Expand All @@ -139,8 +139,8 @@ async def test_set_cookie_with_cookies_model() -> None:
cookies = Cookies()
cookies["example-name"] = "example-value"

client = AsyncClient(transport=MockTransport())
response = await client.get(url, cookies=cookies)
async with AsyncClient(transport=MockTransport()) as client:
response = await client.get(url, cookies=cookies)

assert response.status_code == 200
assert response.json() == {"cookies": "example-name=example-value"}
Expand All @@ -157,6 +157,8 @@ async def test_get_cookie() -> None:
assert response.cookies["example-name"] == "example-value"
assert client.cookies["example-name"] == "example-value"

await client.aclose()


@pytest.mark.asyncio
async def test_cookie_persistence() -> None:
Expand All @@ -177,3 +179,5 @@ async def test_cookie_persistence() -> None:
response = await client.get("http://example.org/echo_cookies")
assert response.status_code == 200
assert response.json() == {"cookies": "example-name=example-value"}

await client.aclose()
22 changes: 12 additions & 10 deletions tests/client/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ async def test_client_header():
url = "http://example.org/echo_headers"
headers = {"Example-Header": "example-value"}

client = AsyncClient(transport=MockTransport(), headers=headers)
response = await client.get(url)
async with AsyncClient(transport=MockTransport(), headers=headers) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {
Expand All @@ -57,8 +57,8 @@ async def test_header_merge():
url = "http://example.org/echo_headers"
client_headers = {"User-Agent": "python-myclient/0.2.1"}
request_headers = {"X-Auth-Token": "FooBarBazToken"}
client = AsyncClient(transport=MockTransport(), headers=client_headers)
response = await client.get(url, headers=request_headers)
async with AsyncClient(transport=MockTransport(), headers=client_headers) as client:
response = await client.get(url, headers=request_headers)

assert response.status_code == 200
assert response.json() == {
Expand All @@ -78,8 +78,8 @@ async def test_header_merge_conflicting_headers():
url = "http://example.org/echo_headers"
client_headers = {"X-Auth-Token": "FooBar"}
request_headers = {"X-Auth-Token": "BazToken"}
client = AsyncClient(transport=MockTransport(), headers=client_headers)
response = await client.get(url, headers=request_headers)
async with AsyncClient(transport=MockTransport(), headers=client_headers) as client:
response = await client.get(url, headers=request_headers)

assert response.status_code == 200
assert response.json() == {
Expand Down Expand Up @@ -127,6 +127,8 @@ async def test_header_update():
}
}

await client.aclose()


def test_header_does_not_exist():
headers = Headers({"foo": "bar"})
Expand All @@ -143,8 +145,8 @@ async def test_host_with_auth_and_port_in_url():
"""
url = "http://username:password@example.org:80/echo_headers"

client = AsyncClient(transport=MockTransport())
response = await client.get(url)
async with AsyncClient(transport=MockTransport()) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {
Expand All @@ -167,8 +169,8 @@ async def test_host_with_non_default_port_in_url():
"""
url = "http://username:password@example.org:123/echo_headers"

client = AsyncClient(transport=MockTransport())
response = await client.get(url)
async with AsyncClient(transport=MockTransport()) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {
Expand Down
Loading

0 comments on commit 458e6c9

Please sign in to comment.