Skip to content

Commit

Permalink
Add wait-for-model header when sending request to Inference API
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Jun 7, 2024
1 parent e43874a commit bf57460
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def post(
headers = self.headers.copy()
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
headers["Accept"] = "image/png"
if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
headers["X-wait-for-model"] = "1"

t0 = time.time()
timeout = self.timeout
Expand Down
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ async def post(
headers = self.headers.copy()
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
headers["Accept"] = "image/png"
if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
headers["X-wait-for-model"] = "1"

t0 = time.time()
timeout = self.timeout
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_mocked_post(self, get_session_mock: MagicMock) -> None:
"https://api-inference.huggingface.co/models/username/repo_name",
json=None,
data=b"content",
headers={"user-agent": expected_user_agent, "X-My-Header": "foo"},
headers={"user-agent": expected_user_agent, "X-My-Header": "foo", "X-wait-for-model": "1"},
cookies={"my-cookie": "bar"},
timeout=None,
stream=False,
Expand Down
2 changes: 2 additions & 0 deletions utils/generate_async_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
headers = self.headers.copy()
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
headers["Accept"] = "image/png"
if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT):
headers["X-wait-for-model"] = "1"
t0 = time.time()
timeout = self.timeout
Expand Down

0 comments on commit bf57460

Please sign in to comment.