Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not raise on .resume() if Inference Endpoint is already running #2335

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/huggingface_hub/_inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,23 @@ def pause(self) -> "InferenceEndpoint":
self._populate_from_raw()
return self

def resume(self) -> "InferenceEndpoint":
def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
"""Resume the Inference Endpoint.

This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
latest data from the server.

Args:
running_ok (`bool`, *optional*):
If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
`True`.

Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
obj = self._api.resume_inference_endpoint(
name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token
) # type: ignore [arg-type]
self.raw = obj.raw
self._populate_from_raw()
return self
Expand Down
19 changes: 17 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7489,7 +7489,12 @@ def pause_inference_endpoint(
return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)

def resume_inference_endpoint(
self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None
self,
name: str,
*,
namespace: Optional[str] = None,
running_ok: bool = True,
token: Union[bool, str, None] = None,
) -> InferenceEndpoint:
"""Resume an Inference Endpoint.

Expand All @@ -7500,6 +7505,9 @@ def resume_inference_endpoint(
The name of the Inference Endpoint to resume.
namespace (`str`, *optional*):
The namespace in which the Inference Endpoint is located. Defaults to the current user.
running_ok (`bool`, *optional*):
If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
`True`.
token (Union[bool, str, None], optional):
A valid user access token (string). Defaults to the locally saved
token, which is the recommended method for authentication (see
Expand All @@ -7515,7 +7523,14 @@ def resume_inference_endpoint(
f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume",
headers=self._build_hf_headers(token=token),
)
hf_raise_for_status(response)
try:
hf_raise_for_status(response)
except HfHubHTTPError as error:
# If already running (and it's ok), then fetch current status and return
if running_ok and error.response.status_code == 400 and "already running" in error.response.text:
return self.get_inference_endpoint(name, namespace=namespace, token=token)
# Otherwise, raise the error
raise

return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,4 @@ def test_resume(mock: Mock):
endpoint = InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo")
mock.return_value = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
endpoint.resume()
mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None)
mock.assert_called_once_with(namespace="foo", name="my-endpoint-name", token=None, running_ok=True)
Loading