diff --git a/src/huggingface_hub/_inference_endpoints.py b/src/huggingface_hub/_inference_endpoints.py index 291e283672..90b9d3248a 100644 --- a/src/huggingface_hub/_inference_endpoints.py +++ b/src/huggingface_hub/_inference_endpoints.py @@ -192,6 +192,12 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "Infere Returns: [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + + Raises: + [`InferenceEndpointError`] + If the Inference Endpoint ended up in a failed state. + [`InferenceEndpointTimeoutError`] + If the Inference Endpoint is not deployed after `timeout` seconds. """ if self.url is not None: # Means the endpoint is deployed logger.info("Inference Endpoint is ready to be used.") @@ -208,6 +214,10 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "Infere if self.url is not None: # Means the endpoint is deployed logger.info("Inference Endpoint is ready to be used.") return self + if self.status == InferenceEndpointStatus.FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." + ) if timeout is not None: if time.time() - start > timeout: raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") diff --git a/tests/test_inference_endpoints.py b/tests/test_inference_endpoints.py index 66da674072..319535ba77 100644 --- a/tests/test_inference_endpoints.py +++ b/tests/test_inference_endpoints.py @@ -76,6 +76,37 @@ }, } +MOCK_FAILED = { + "name": "my-endpoint-name", + "type": "protected", + "accountId": None, + "provider": {"vendor": "aws", "region": "us-east-1"}, + "compute": { + "accelerator": "cpu", + "instanceType": "c6i", + "instanceSize": "medium", + "scaling": {"minReplica": 0, "maxReplica": 1}, + }, + "model": { + "repository": "gpt2", + "revision": "11c5a3d5811f50298f278a704980280950aedb10", + "task": "text-generation", + "framework": "pytorch", + "image": {"huggingface": {}}, + }, + "status": { + "createdAt": "2023-10-26T12:41:53.263Z", + "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, + "updatedAt": "2023-10-26T12:41:53.263Z", + "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, + "private": None, + "state": "failed", + "message": "Endpoint failed to deploy", + "readyReplica": 0, + "targetReplica": 1, + }, +} + def test_from_raw_initialization(): """Test InferenceEndpoint is correctly initialized from raw dict.""" @@ -188,6 +219,20 @@ def test_wait_timeout(mock_get: Mock): assert len(mock_get.call_args_list) == 3 +@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint") +def test_wait_failed(mock_get: Mock): + """Test waits until timeout error is raised.""" + endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo") + + mock_get.side_effect = [ + InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), + InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo"), + InferenceEndpoint.from_raw(MOCK_FAILED, namespace="foo"), + ] + with pytest.raises(InferenceEndpointError, match=".*failed to deploy.*"): + endpoint.wait(refresh_every=0.001) + + @patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint") def test_pause(mock: Mock): """Test `pause` calls the correct alias."""