Skip to content

Commit

Permalink
feat: Add sdk support to inference timeout on cloud-based endpoints (…
Browse files Browse the repository at this point in the history
…dedicated or PSC).

PiperOrigin-RevId: 699325577
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 23, 2024
1 parent 1487846 commit f917269
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
31 changes: 31 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def create(
request_response_logging_sampling_rate: Optional[float] = None,
request_response_logging_bq_destination_table: Optional[str] = None,
dedicated_endpoint_enabled=False,
inference_timeout: Optional[int] = None,
) -> "Endpoint":
"""Creates a new endpoint.
Expand Down Expand Up @@ -854,6 +855,8 @@ def create(
Optional. If enabled, a dedicated dns will be created and your
traffic will be fully isolated from other customers' traffic and
latency will be reduced.
inference_timeout (int):
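Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true. In this method, the timeout is only applied when it is greater than 0 and dedicated_endpoint_enabled is true; otherwise it is ignored.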
Returns:
endpoint (aiplatform.Endpoint):
Expand Down Expand Up @@ -882,6 +885,17 @@ def create(
),
)
)

client_connection_config = None
if (
inference_timeout is not None
and inference_timeout > 0
and dedicated_endpoint_enabled
):
client_connection_config = gca_endpoint_compat.ClientConnectionConfig(
inference_timeout=duration_pb2.Duration(seconds=inference_timeout)
)

return cls._create(
api_client=api_client,
display_name=display_name,
Expand All @@ -899,6 +913,7 @@ def create(
endpoint_id=endpoint_id,
predict_request_response_logging_config=predict_request_response_logging_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
client_connection_config=client_connection_config,
)

@classmethod
Expand All @@ -925,6 +940,9 @@ def _create(
gca_service_networking.PrivateServiceConnectConfig
] = None,
dedicated_endpoint_enabled=False,
client_connection_config: Optional[
gca_endpoint_compat.ClientConnectionConfig
] = None,
) -> "Endpoint":
"""Creates a new endpoint by calling the API client.
Expand Down Expand Up @@ -995,6 +1013,8 @@ def _create(
Optional. If enabled, a dedicated dns will be created and your
traffic will be fully isolated from other customers' traffic and
latency will be reduced.
client_connection_config (aiplatform.endpoint.ClientConnectionConfig):
Optional. The client connection configuration carrying the inference timeout that is applied on cloud-based (PSC or dedicated) endpoints for online prediction.
Returns:
endpoint (aiplatform.Endpoint):
Expand All @@ -1014,6 +1034,7 @@ def _create(
predict_request_response_logging_config=predict_request_response_logging_config,
private_service_connect_config=private_service_connect_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
client_connection_config=client_connection_config,
)

operation_future = api_client.create_endpoint(
Expand Down Expand Up @@ -3253,6 +3274,7 @@ def create(
encryption_spec_key_name: Optional[str] = None,
sync=True,
private_service_connect_config: Optional[PrivateServiceConnectConfig] = None,
inference_timeout: Optional[int] = None,
) -> "PrivateEndpoint":
"""Creates a new PrivateEndpoint.
Expand Down Expand Up @@ -3338,6 +3360,8 @@ def create(
private_service_connect_config (aiplatform.PrivateEndpoint.PrivateServiceConnectConfig):
[Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect) configuration for the endpoint.
Cannot be set when network is specified.
inference_timeout (int):
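Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true. In this method, the timeout is only applied when private_service_connect_config is set and a non-zero value is given; otherwise it is ignored.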
Returns:
endpoint (aiplatform.PrivateEndpoint):
Expand Down Expand Up @@ -3374,6 +3398,12 @@ def create(
private_service_connect_config._gapic_private_service_connect_config
)

client_connection_config = None
if private_service_connect_config and inference_timeout:
client_connection_config = gca_endpoint_compat.ClientConnectionConfig(
inference_timeout=duration_pb2.Duration(seconds=inference_timeout)
)

return cls._create(
api_client=api_client,
display_name=display_name,
Expand All @@ -3388,6 +3418,7 @@ def create(
network=network,
sync=sync,
private_service_connect_config=config,
client_connection_config=client_connection_config,
)

@classmethod
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import requests
from unittest import mock
from google.protobuf import duration_pb2

from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -269,6 +270,11 @@
)
)

# Timeout value, in seconds, passed as `inference_timeout` by the
# endpoint-creation tests in this module.
_TEST_INFERENCE_TIMEOUT = 100
# Client connection config the creation tests expect on the endpoint sent to
# the API: the test timeout wrapped in a protobuf Duration.
_TEST_CLIENT_CONNECTION_CONFIG = gca_endpoint.ClientConnectionConfig(
inference_timeout=duration_pb2.Duration(seconds=_TEST_INFERENCE_TIMEOUT)
)

"""
----------------------------------------------------------------------------
Endpoint Fixtures
Expand Down Expand Up @@ -1258,6 +1264,34 @@ def test_create_dedicated_endpoint(self, create_dedicated_endpoint_mock, sync):
endpoint_id=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_create_dedicated_endpoint_with_timeout(
    self, create_dedicated_endpoint_mock, sync
):
    """Checks that a dedicated endpoint created with a timeout carries it.

    Calls ``models.Endpoint.create`` with ``dedicated_endpoint_enabled=True``
    and ``inference_timeout`` set, then verifies the mocked create call
    received an endpoint whose ``client_connection_config`` equals
    ``_TEST_CLIENT_CONNECTION_CONFIG``.
    """
    created = models.Endpoint.create(
        inference_timeout=_TEST_INFERENCE_TIMEOUT,
        dedicated_endpoint_enabled=True,
        display_name=_TEST_DISPLAY_NAME,
        location=_TEST_LOCATION,
        project=_TEST_PROJECT,
        sync=sync,
    )

    # In the non-sync case, wait on the endpoint before inspecting the mock.
    if not sync:
        created.wait()

    create_dedicated_endpoint_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        endpoint=gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            dedicated_endpoint_enabled=True,
            client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG,
        ),
        metadata=(),
        timeout=None,
        endpoint_id=None,
    )

@pytest.mark.usefixtures("get_empty_endpoint_mock")
def test_accessing_properties_with_no_resource_raises(
self,
Expand Down Expand Up @@ -3441,6 +3475,39 @@ def test_create_psc(self, create_psc_private_endpoint_mock, sync):
endpoint_id=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_create_psc_with_timeout(self, create_psc_private_endpoint_mock, sync):
    """Checks that a PSC private endpoint created with a timeout carries it.

    Calls ``models.PrivateEndpoint.create`` with a Private Service Connect
    config and ``inference_timeout`` set, then verifies the mocked create call
    received an endpoint whose ``client_connection_config`` equals
    ``_TEST_CLIENT_CONNECTION_CONFIG``.
    """
    psc_config = models.PrivateEndpoint.PrivateServiceConnectConfig(
        project_allowlist=_TEST_PROJECT_ALLOWLIST
    )
    created = models.PrivateEndpoint.create(
        inference_timeout=_TEST_INFERENCE_TIMEOUT,
        private_service_connect_config=psc_config,
        display_name=_TEST_DISPLAY_NAME,
        location=_TEST_LOCATION,
        project=_TEST_PROJECT,
        sync=sync,
    )

    # In the non-sync case, wait on the endpoint before inspecting the mock.
    if not sync:
        created.wait()

    expected_psc = gca_service_networking.PrivateServiceConnectConfig(
        enable_private_service_connect=True,
        project_allowlist=_TEST_PROJECT_ALLOWLIST,
    )
    create_psc_private_endpoint_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        endpoint=gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            private_service_connect_config=expected_psc,
            client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG,
        ),
        metadata=(),
        timeout=None,
        endpoint_id=None,
    )

@pytest.mark.usefixtures("get_psa_private_endpoint_with_model_mock")
def test_psa_predict(self, predict_private_endpoint_mock):
test_endpoint = models.PrivateEndpoint(_TEST_ID)
Expand Down

0 comments on commit f917269

Please sign in to comment.