From 8ea3f1d16eb225e6f46308e568f907be2b699ee2 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Fri, 7 Jun 2024 15:53:26 +0200 Subject: [PATCH] Add X-wait-for-model only after first call --- src/huggingface_hub/inference/_client.py | 4 ++-- src/huggingface_hub/inference/_generated/_async_client.py | 4 ++-- utils/generate_async_inference_client.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index b0366634bc..d09177249e 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -250,8 +250,6 @@ def post( headers = self.headers.copy() if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: headers["Accept"] = "image/png" - if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): - headers["X-wait-for-model"] = "1" t0 = time.time() timeout = self.timeout @@ -291,6 +289,8 @@ def post( # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") time.sleep(1) + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore continue diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 90af4d8505..d7571de176 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -238,8 +238,6 @@ async def post( headers = self.headers.copy() if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: headers["Accept"] = "image/png" - if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): - headers["X-wait-for-model"] = "1" t0 = time.time() timeout = self.timeout @@ -286,6 +284,8 @@ async def post( ) from error # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" time.sleep(1) if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index 7147b66c24..eaa0d7f04d 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -186,8 +186,6 @@ def _rename_to_AsyncInferenceClient(code: str) -> str: headers = self.headers.copy() if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: headers["Accept"] = "image/png" - if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): - headers["X-wait-for-model"] = "1" t0 = time.time() timeout = self.timeout @@ -234,6 +232,8 @@ def _rename_to_AsyncInferenceClient(code: str) -> str: ) from error # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" time.sleep(1) if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore