Properly close session in AsyncInferenceClient #2496

Merged · 3 commits · Sep 2, 2024
59 changes: 51 additions & 8 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -32,6 +32,7 @@
     List,
     Literal,
     Optional,
+    Set,
     Union,
     overload,
 )
@@ -188,6 +189,9 @@ def __init__(
         # OpenAI compatibility
         self.base_url = base_url
 
+        # Keep track of the sessions to close them properly
+        self._sessions: Set["ClientSession"] = set()
+
     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

@@ -282,10 +286,10 @@ async def post(
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = self._get_client_session(headers=headers)
+            session = self._get_client_session(headers=headers)
 
             try:
-                response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
+                response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                 response_error_payload = None
                 if response.status != 200:
                     try:
@@ -294,18 +298,18 @@
                         pass
                 response.raise_for_status()
                 if stream:
-                    return _async_yield_from(client, response)
+                    return _async_yield_from(session, response)
                 else:
                     content = await response.read()
-                    await client.close()
+                    await session.close()
                     return content
             except asyncio.TimeoutError as error:
-                await client.close()
+                await session.close()
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
                 raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
-                await client.close()
+                await session.close()
                 if response.status == 422 and task is not None:
                     error.message += f". Make sure '{task}' task is supported by the model."
                 if response.status == 503:
@@ -327,9 +331,35 @@
                     continue
                 raise error
             except Exception:
-                await client.close()
+                await session.close()
                 raise
 
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+    def __del__(self):
+        if len(self._sessions) > 0:
+            warnings.warn(
+                "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
+                "This can happen if you've stopped streaming data from the server before the stream was complete. "
+                "To close the client properly, you must call `await client.close()` "
+                "or use an async context (e.g. `async with AsyncInferenceClient(): ...`)."
+            )
Comment on lines +343 to +350

Member:

Out of curiosity, shouldn't this just close the session as well?

Contributor Author:

Well well well...

I had a hard time getting this to work. The problem is that you don't know when `__del__` will be called by Python's garbage collector. Since `self.close` is an async function, you need an active event loop to run it, which is not guaranteed at that point. I first tried some hacks but gave up, given the complexity compared to having a proper `close` method or using an async context manager.

This answer on Stack Overflow also convinced me to do things this way.
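To make the constraint in this reply concrete, here is a minimal, self-contained sketch (a hypothetical `Resource` class, not code from this PR) of why an async `close()` cannot reliably be awaited from `__del__`:

```python
import asyncio
import warnings


class Resource:
    """Hypothetical resource with an async close, mirroring the client's situation."""

    def __init__(self):
        self.closed = False

    async def close(self):
        self.closed = True

    def __del__(self):
        if not self.closed:
            # `self.close()` only *creates* a coroutine here; awaiting it would
            # require a running event loop, which the garbage collector does not
            # guarantee. Warning and requiring an explicit close is the pragmatic
            # fallback, which is exactly what `__del__` above does.
            warnings.warn("Resource deleted but never closed; call `await resource.close()`.")


async def main():
    leaked = Resource()   # never closed: triggers the warning at garbage collection
    closed = Resource()
    await closed.close()  # explicit close, analogous to `await client.close()`


asyncio.run(main())
```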


+    async def close(self):
+        """Close all open sessions.
+
+        By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
+        are streaming data from the server and you stop before the stream is complete, you must call this method to
+        close the session properly.
+
+        Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
+        """
+        await asyncio.gather(*[session.close() for session in self._sessions])
+
     async def audio_classification(
         self,
         audio: ContentT,
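For reference, here is how the API added in this hunk (`__aenter__`, `__aexit__`, `close`) is meant to be used. A short sketch; the model ID is borrowed from the tests below, and the generation parameters are illustrative:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # Preferred: the async context manager awaits `close()` on exit.
    async with AsyncInferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct") as client:
        stream = await client.text_generation("Hi", stream=True, max_new_tokens=5)
        async for token in stream:
            print(token)

    # Alternative: close explicitly, e.g. after abandoning a stream early,
    # which is the case where a session would otherwise stay open.
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
    stream = await client.text_generation("Hi", stream=True, max_new_tokens=5)
    await anext(stream)   # consume a single token, then stop (Python 3.10+)
    await client.close()  # closes the session the abandoned stream left open


asyncio.run(main())
```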
@@ -2610,13 +2640,26 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
         client_headers.update(headers)
 
         # Return a new aiohttp ClientSession with correct settings.
-        return aiohttp.ClientSession(
+        session = aiohttp.ClientSession(
             headers=client_headers,
             cookies=self.cookies,
             timeout=aiohttp.ClientTimeout(self.timeout),
             trust_env=self.trust_env,
         )
 
+        # Keep track of sessions to close them later
+        self._sessions.add(session)
+
+        # Override the 'close' method to deregister the session when closed
+        session._close = session.close
+
+        async def close_session():
+            await session._close()
+            self._sessions.discard(session)
+
+        session.close = close_session
+        return session

     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model or self.base_url
 
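The streaming branch in `post` (`return _async_yield_from(session, response)`) is why a session can outlive the call: it must stay open until the caller has consumed the whole stream. A rough sketch of that helper's shape, stated as an assumption since the actual `_async_yield_from` implementation is not shown in this diff:

```python
from typing import AsyncIterator

import aiohttp


async def _yield_from_sketch(
    session: aiohttp.ClientSession, response: aiohttp.ClientResponse
) -> AsyncIterator[bytes]:
    # Stream the response body, then close the session once the stream is
    # fully consumed. If the caller stops iterating early, this close is never
    # reached; that is the leak the `_sessions` registry and `client.close()` fix.
    async for chunk in response.content:
        yield chunk
    await session.close()
```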
159 changes: 159 additions & 0 deletions tests/cassettes/test_http_session_correctly_closed.yaml
@@ -0,0 +1,159 @@
interactions:
- request:
    body: null
    headers:
      user-agent:
      - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0;
        fastcore/1.5.23
    method: POST
    uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct
  response:
    body:
      string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":"
        the","details":null}


        '
    headers:
      Access-Control-Allow-Credentials:
      - 'true'
      Access-Control-Allow-Origin:
      - '*'
      Cache-Control:
      - no-cache
      Connection:
      - keep-alive
      Content-Type:
      - text/event-stream
      Date:
      - Fri, 30 Aug 2024 13:58:50 GMT
      Transfer-Encoding:
      - chunked
      Vary:
      - origin, access-control-request-method, access-control-request-headers, Origin,
        Access-Control-Request-Method, Access-Control-Request-Headers
      x-accel-buffering:
      - 'no'
      x-compute-characters:
      - '41'
      x-compute-type:
      - 2-a10-g
      x-request-id:
      - 8ZMGhj7Cj90UK12dSJhfj
      x-sha:
      - 5206a32e0bd3067aef1ce90f5528ade7d866253f
    status:
      code: 200
      message: OK
- request:
    body: null
    headers:
      user-agent:
      - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0;
        fastcore/1.5.23
    method: POST
    uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct
  response:
    body:
      string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":"
        the","details":null}


        '
    headers:
      Access-Control-Allow-Credentials:
      - 'true'
      Connection:
      - keep-alive
      Content-Length:
      - '129'
      Content-Type:
      - text/event-stream
      Date:
      - Fri, 30 Aug 2024 13:58:51 GMT
      Vary:
      - Origin, Access-Control-Request-Method, Access-Control-Request-Headers
      x-compute-type:
      - cache
      x-request-id:
      - gG5snagDwsUuzxQUSHpfs
      x-sha:
      - 5206a32e0bd3067aef1ce90f5528ade7d866253f
    status:
      code: 200
      message: OK
- request:
    body: null
    headers:
      user-agent:
      - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0;
        fastcore/1.5.23
    method: POST
    uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct
  response:
    body:
      string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":"
        the","details":null}


        '
    headers:
      Access-Control-Allow-Credentials:
      - 'true'
      Connection:
      - keep-alive
      Content-Length:
      - '129'
      Content-Type:
      - text/event-stream
      Date:
      - Fri, 30 Aug 2024 13:58:51 GMT
      Vary:
      - Origin, Access-Control-Request-Method, Access-Control-Request-Headers
      x-compute-type:
      - cache
      x-request-id:
      - 8PS-J7-J2QKF0FyOb9War
      x-sha:
      - 5206a32e0bd3067aef1ce90f5528ade7d866253f
    status:
      code: 200
      message: OK
- request:
    body: null
    headers:
      user-agent:
      - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0;
        fastcore/1.5.23
    method: POST
    uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct
  response:
    body:
      string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":"
        the","details":null}


        '
    headers:
      Access-Control-Allow-Credentials:
      - 'true'
      Connection:
      - keep-alive
      Content-Length:
      - '129'
      Content-Type:
      - text/event-stream
      Date:
      - Fri, 30 Aug 2024 13:58:52 GMT
      Vary:
      - Origin, Access-Control-Request-Method, Access-Control-Request-Headers
      x-compute-type:
      - cache
      x-request-id:
      - _MFpbB44jCKcoM8wai66z
      x-sha:
      - 5206a32e0bd3067aef1ce90f5528ade7d866253f
    status:
      code: 200
      message: OK
version: 1
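For context, this cassette is replayed by the `@pytest.mark.vcr` marker in the test below, so the regression test runs without network access. A minimal sketch of the same mechanism using vcrpy directly; the directory and file names are taken from this PR, while the record mode and setup are assumptions rather than the repo's actual test configuration:

```python
import vcr

# Replay recorded HTTP interactions from the cassette; record them on first run.
my_vcr = vcr.VCR(cassette_library_dir="tests/cassettes", record_mode="once")

with my_vcr.use_cassette("test_http_session_correctly_closed.yaml"):
    ...  # HTTP requests made here are served from the YAML above
```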
49 changes: 49 additions & 0 deletions tests/test_inference_async_client.py
@@ -409,3 +409,52 @@ async def test_openai_compatibility_with_stream_true():
     chunked_text = [chunk.choices[0].delta.content async for chunk in output]
     assert len(chunked_text) == 34
     assert "".join(chunked_text) == "Here it goes:\n\n1, 2, 3, 4, 5, 6, 7, 8, 9, 10!"
+
+
+@pytest.mark.vcr
+@pytest.mark.asyncio
+@with_production_testing
+async def test_http_session_correctly_closed() -> None:
+    """
+    Regression test for #2493.
+    The async client should close the HTTP session once a request is done.
+    This is always the case, except for streamed responses when the stream is not fully consumed.
+    Fixed by keeping track of open sessions and closing them all when closing the client.
+
+    See https://github.com/huggingface/huggingface_hub/issues/2493.
+    """
+    client = AsyncInferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
+    kwargs = {"prompt": "Hi", "stream": True, "max_new_tokens": 1}
+
+    # Create a session, close it, and check that it is correctly unregistered
+    await client.text_generation(**kwargs)
+    assert len(client._sessions) == 1
+    await list(client._sessions)[0].close()
+    assert len(client._sessions) == 0
+
+    # Create multiple sessions, close the AsyncInferenceClient, and check that all are unregistered
+    await client.text_generation(**kwargs)
+    await client.text_generation(**kwargs)
+    await client.text_generation(**kwargs)
+
+    assert len(client._sessions) == 3
+    await client.close()
+    assert len(client._sessions) == 0
+
+
+@pytest.mark.asyncio
+async def test_use_async_with_inference_client():
+    with patch("huggingface_hub.AsyncInferenceClient.close") as mock_close:
+        async with AsyncInferenceClient():
+            pass
+    mock_close.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_warns_if_client_deleted_with_opened_sessions():
+    client = AsyncInferenceClient()
+    session = client._get_client_session()
+    with pytest.warns(UserWarning):
+        client.__del__()
+    await session.close()