From fb8b3f02f8374487fdcff549228322f66a9bbf7b Mon Sep 17 00:00:00 2001
From: Lucain Pouget
Date: Fri, 30 Aug 2024 14:34:27 +0200
Subject: [PATCH 1/3] Properly close session in AsyncInferenceClient

---
 .../inference/_generated/_async_client.py | 39 ++++++++++++----
 utils/generate_async_inference_client.py  | 44 +++++++++++++++----
 2 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index f34b4e33fd..342161afd5 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -32,6 +32,7 @@
     List,
     Literal,
     Optional,
+    Set,
     Union,
     overload,
 )
@@ -188,6 +189,9 @@ def __init__(
         # OpenAI compatibility
         self.base_url = base_url
 
+        # Keep track of the sessions to close them properly
+        self._sessions: Set["ClientSession"] = set()
+
     def __repr__(self):
         return f"<InferenceClient(model='{self.model}', timeout={self.timeout})>"
 
@@ -282,10 +286,10 @@ async def post(
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = self._get_client_session(headers=headers)
+            session = self._get_client_session(headers=headers)
 
             try:
-                response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
+                response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                 response_error_payload = None
                 if response.status != 200:
                     try:
@@ -294,18 +298,18 @@ async def post(
                         pass
                 response.raise_for_status()
                 if stream:
-                    return _async_yield_from(client, response)
+                    return _async_yield_from(session, response)
                 else:
                     content = await response.read()
-                    await client.close()
+                    await session.close()
                     return content
             except asyncio.TimeoutError as error:
-                await client.close()
+                await session.close()
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
                 raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
-                await client.close()
+                await session.close()
                 if response.status == 422 and task is not None:
                     error.message += f". Make sure '{task}' task is supported by the model."
                 if response.status == 503:
@@ -327,9 +331,15 @@ async def post(
                         continue
                 raise error
             except Exception:
-                await client.close()
+                await session.close()
                 raise
 
+    def __del__(self):
+        async def _close_all_sessions():
+            await asyncio.gather(*[session.close() for session in self._sessions])
+
+        asyncio.run(_close_all_sessions())
+
     async def audio_classification(
         self,
         audio: ContentT,
@@ -2610,13 +2620,26 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
             client_headers.update(headers)
 
         # Return a new aiohttp ClientSession with correct settings.
-        return aiohttp.ClientSession(
+        session = aiohttp.ClientSession(
             headers=client_headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout), trust_env=self.trust_env,
         )
+        # Keep track of sessions to close them later
+        self._sessions.add(session)
+
+        # Override the 'close' method to deregister the session when closed
+        session._close = session.close
+
+        async def close_session():
+            await session._close()
+            self._sessions.discard(session)
+
+        session.close = close_session
+        return session
+
     def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
         model = model or self.model or self.base_url
 
diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py
index 663bde255f..b4ad9a2d07 100644
--- a/utils/generate_async_inference_client.py
+++ b/utils/generate_async_inference_client.py
@@ -157,6 +157,7 @@ def _add_imports(code: str) -> str:
             r"\1"
             + "from .._common import _async_yield_from, _import_aiohttp\n"
             + "from typing import AsyncIterable\n"
+            + "from typing import Set\n"
             + "import asyncio\n"
         ),
         string=code,
@@ -199,10 +200,10 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
         with _open_as_binary(data) as data_as_binary:
             # Do not use context manager as we don't want to close the connection immediately when returning
             # a stream
-            client = self._get_client_session(headers=headers)
+            session = self._get_client_session(headers=headers)
 
             try:
-                response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
+                response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies)
                 response_error_payload = None
                 if response.status != 200:
                     try:
@@ -211,18 +212,18 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
                         pass
                 response.raise_for_status()
                 if stream:
-                    return _async_yield_from(client, response)
+                    return _async_yield_from(session, response)
                 else:
                     content = await response.read()
-                    await client.close()
+                    await session.close()
                     return content
             except asyncio.TimeoutError as error:
-                await client.close()
+                await session.close()
                 # Convert any `TimeoutError` to a `InferenceTimeoutError`
                 raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
             except aiohttp.ClientResponseError as error:
                 error.response_error_payload = response_error_payload
-                await client.close()
+                await session.close()
                 if response.status == 422 and task is not None:
                     error.message += f". Make sure '{task}' task is supported by the model."
                 if response.status == 503:
@@ -244,8 +245,13 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
                         continue
                 raise error
             except Exception:
-                await client.close()
-                raise"""
+                await session.close()
+                raise
+
+    def __del__(self):
+        async def _close_all_sessions():
+            await asyncio.gather(*[session.close() for session in self._sessions])
+        asyncio.run(_close_all_sessions())"""
 
 
 def _make_post_async(code: str) -> str:
@@ -500,16 +506,36 @@ def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession"
             client_headers.update(headers)
 
         # Return a new aiohttp ClientSession with correct settings.
-        return aiohttp.ClientSession(
+        session = aiohttp.ClientSession(
             headers=client_headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout), trust_env=self.trust_env,
         )
+        # Keep track of sessions to close them later
+        self._sessions.add(session)
+
+        # Override the 'close' method to deregister the session when closed
+        session._close = session.close
+
+        async def close_session():
+            await session._close()
+            self._sessions.discard(session)
+
+        session.close = close_session
+        return session
+
 """
     code = _add_before(code, "\n    def _resolve_url(", client_session_code)
 
+    # Add self._sessions attribute in __init__
+    code = _add_before(
+        code,
+        "\n    def __repr__(self):\n",
+        "\n        # Keep track of the sessions to close them properly\n        self._sessions: Set['ClientSession'] = set()",
+    )
+
     return code

From 0d825d8305c7b7e9578ed75ae06115634bef402b Mon Sep 17 00:00:00 2001
From: Lucain Pouget
Date: Fri, 30 Aug 2024 16:22:02 +0200
Subject: [PATCH 2/3] add proper .close method + tests

---
 .../inference/_generated/_async_client.py    |  26 ++-
 .../test_http_session_correctly_closed.yaml  | 159 ++++++++++++++++++
 tests/test_inference_async_client.py         |  49 ++++++
 utils/generate_async_inference_client.py     |  27 ++-
 4 files changed, 255 insertions(+), 6 deletions(-)
 create mode 100644 tests/cassettes/test_http_session_correctly_closed.yaml

diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 342161afd5..c11249f181 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -334,11 +334,31 @@ async def post(
                 await session.close()
                 raise
 
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
     def __del__(self):
-        async def _close_all_sessions():
-            await asyncio.gather(*[session.close() for session in self._sessions])
+        if len(self._sessions) > 0:
+            logger.warning(
+                "Deleting 'AsyncInferenceClient' client but some sessions are still open. "
+                "This can happen if you've stopped streaming data from the server before the stream was complete. "
+                "To close the client properly, you must call `await client.close()` "
+                "or use an async context (e.g. `async with AsyncInferenceClient(): ...`)."
+            )
 
-        asyncio.run(_close_all_sessions())
+    async def close(self):
+        """Close all open sessions.
+
+        By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you
+        are streaming data from the server and you stop before the stream is complete, you must call this method to
+        close the session properly.
+
+        Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`).
+ """ + await asyncio.gather(*[session.close() for session in self._sessions]) async def audio_classification( self, diff --git a/tests/cassettes/test_http_session_correctly_closed.yaml b/tests/cassettes/test_http_session_correctly_closed.yaml new file mode 100644 index 0000000000..46ec8fce34 --- /dev/null +++ b/tests/cassettes/test_http_session_correctly_closed.yaml @@ -0,0 +1,159 @@ +interactions: +- request: + body: null + headers: + user-agent: + - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct + response: + body: + string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":" + the","details":null} + + + ' + headers: + Access-Control-Allow-Credentials: + - 'true' + Access-Control-Allow-Origin: + - '*' + Cache-Control: + - no-cache + Connection: + - keep-alive + Content-Type: + - text/event-stream + Date: + - Fri, 30 Aug 2024 13:58:50 GMT + Transfer-Encoding: + - chunked + Vary: + - origin, access-control-request-method, access-control-request-headers, Origin, + Access-Control-Request-Method, Access-Control-Request-Headers + x-accel-buffering: + - 'no' + x-compute-characters: + - '41' + x-compute-type: + - 2-a10-g + x-request-id: + - 8ZMGhj7Cj90UK12dSJhfj + x-sha: + - 5206a32e0bd3067aef1ce90f5528ade7d866253f + status: + code: 200 + message: OK +- request: + body: null + headers: + user-agent: + - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct + response: + body: + string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":" + the","details":null} + + + ' + headers: + Access-Control-Allow-Credentials: + - 'true' + Connection: + - keep-alive + Content-Length: + - '129' + Content-Type: + - text/event-stream + Date: + - Fri, 30 Aug 2024 13:58:51 GMT + Vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-type: + - cache + x-request-id: + - gG5snagDwsUuzxQUSHpfs + x-sha: + - 5206a32e0bd3067aef1ce90f5528ade7d866253f + status: + code: 200 + message: OK +- request: + body: null + headers: + user-agent: + - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct + response: + body: + string: 'data: {"index":1,"token":{"id":279,"text":" the","logprob":-0.63378906,"special":false},"generated_text":" + the","details":null} + + + ' + headers: + Access-Control-Allow-Credentials: + - 'true' + Connection: + - keep-alive + Content-Length: + - '129' + Content-Type: + - text/event-stream + Date: + - Fri, 30 Aug 2024 13:58:51 GMT + Vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-type: + - cache + x-request-id: + - 8PS-J7-J2QKF0FyOb9War + x-sha: + - 5206a32e0bd3067aef1ce90f5528ade7d866253f + status: + code: 200 + message: OK +- request: + body: null + headers: + user-agent: + - unknown/None; hf_hub/0.25.0.dev0; python/3.10.12; torch/2.4.0; tensorflow/2.17.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct + response: + body: + string: 'data: {"index":1,"token":{"id":279,"text":" 
the","logprob":-0.63378906,"special":false},"generated_text":" + the","details":null} + + + ' + headers: + Access-Control-Allow-Credentials: + - 'true' + Connection: + - keep-alive + Content-Length: + - '129' + Content-Type: + - text/event-stream + Date: + - Fri, 30 Aug 2024 13:58:52 GMT + Vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-type: + - cache + x-request-id: + - _MFpbB44jCKcoM8wai66z + x-sha: + - 5206a32e0bd3067aef1ce90f5528ade7d866253f + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_inference_async_client.py b/tests/test_inference_async_client.py index fab382fa64..13922e4f89 100644 --- a/tests/test_inference_async_client.py +++ b/tests/test_inference_async_client.py @@ -409,3 +409,52 @@ async def test_openai_compatibility_with_stream_true(): chunked_text = [chunk.choices[0].delta.content async for chunk in output] assert len(chunked_text) == 34 assert "".join(chunked_text) == "Here it goes:\n\n1, 2, 3, 4, 5, 6, 7, 8, 9, 10!" + + +@pytest.mark.vcr +@pytest.mark.asyncio +@with_production_testing +async def test_http_session_correctly_closed() -> None: + """ + Regression test for #2493. + Async client should close the HTTP session after the request is done. + This is always done except for streamed responses if the stream is not fully consumed. + Fixed by keeping a list of sessions and closing them all when deleting the client. + + See https://github.com/huggingface/huggingface_hub/issues/2493. + """ + + client = AsyncInferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct") + kwargs = {"prompt": "Hi", "stream": True, "max_new_tokens": 1} + + # Test create session + close it + check correctly unregistered + await client.text_generation(**kwargs) + assert len(client._sessions) == 1 + await list(client._sessions)[0].close() + assert len(client._sessions) == 0 + + # Test create multiple sessions + close AsyncInferenceClient + check correctly unregistered + await client.text_generation(**kwargs) + await client.text_generation(**kwargs) + await client.text_generation(**kwargs) + + assert len(client._sessions) == 3 + await client.close() + assert len(client._sessions) == 0 + + +@pytest.mark.asyncio +async def test_use_async_with_inference_client(): + with patch("huggingface_hub.AsyncInferenceClient.close") as mock_close: + async with AsyncInferenceClient(): + pass + mock_close.assert_called_once() + + +@pytest.mark.asyncio +async def test_warns_if_client_deleted_with_opened_sessions(): + client = AsyncInferenceClient() + session = client._get_client_session() + with pytest.warns(UserWarning): + client.__del__() + await session.close() diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index b4ad9a2d07..ae939e47dc 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -248,10 +248,31 @@ def _rename_to_AsyncInferenceClient(code: str) -> str: await session.close() raise + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + def __del__(self): - async def _close_all_sessions(): - await asyncio.gather(*[session.close() for session in self._sessions]) - asyncio.run(_close_all_sessions())""" + if len(self._sessions) > 0: + logger.warning( + "Deleting 'AsyncInferenceClient' client but some sessions are still open. " + "This can happen if you've stopped streaming data from the server before the stream was complete. 
" + "To close the client properly, you must call `await client.close()` " + "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." + ) + + async def close(self): + \"""Close all open sessions. + + By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you + are streaming data from the server and you stop before the stream is complete, you must call this method to + close the session properly. + + Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). + \""" + await asyncio.gather(*[session.close() for session in self._sessions])""" def _make_post_async(code: str) -> str: From 9ca90f29a4354b3ad04393e236c75fb79294474e Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 30 Aug 2024 16:42:03 +0200 Subject: [PATCH 3/3] proper warning --- src/huggingface_hub/inference/_generated/_async_client.py | 2 +- utils/generate_async_inference_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index c11249f181..fd7343ea09 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -342,7 +342,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __del__(self): if len(self._sessions) > 0: - logger.warning( + warnings.warn( "Deleting 'AsyncInferenceClient' client but some sessions are still open. " "This can happen if you've stopped streaming data from the server before the stream was complete. " "To close the client properly, you must call `await client.close()` " diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index ae939e47dc..4d607f7f8c 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -256,7 +256,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __del__(self): if len(self._sessions) > 0: - logger.warning( + warnings.warn( "Deleting 'AsyncInferenceClient' client but some sessions are still open. " "This can happen if you've stopped streaming data from the server before the stream was complete. " "To close the client properly, you must call `await client.close()` "