From f917269b35b6582aecabd7a75610b2225407ae1f Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 22 Nov 2024 16:48:27 -0800 Subject: [PATCH] feat: Add sdk support to inference timeout on cloud-based endpoints (dedicated or PSC). PiperOrigin-RevId: 699325577 --- google/cloud/aiplatform/models.py | 31 ++++++++++++ tests/unit/aiplatform/test_endpoints.py | 67 +++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index bbc7d9ab72..e3773712b9 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -783,6 +783,7 @@ def create( request_response_logging_sampling_rate: Optional[float] = None, request_response_logging_bq_destination_table: Optional[str] = None, dedicated_endpoint_enabled=False, + inference_timeout: Optional[int] = None, ) -> "Endpoint": """Creates a new endpoint. @@ -854,6 +855,8 @@ def create( Optional. If enabled, a dedicated dns will be created and your traffic will be fully isolated from other customers' traffic and latency will be reduced. + inference_timeout (int): + Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true. Returns: endpoint (aiplatform.Endpoint): @@ -882,6 +885,17 @@ def create( ), ) ) + + client_connection_config = None + if ( + inference_timeout is not None + and inference_timeout > 0 + and dedicated_endpoint_enabled + ): + client_connection_config = gca_endpoint_compat.ClientConnectionConfig( + inference_timeout=duration_pb2.Duration(seconds=inference_timeout) + ) + return cls._create( api_client=api_client, display_name=display_name, @@ -899,6 +913,7 @@ def create( endpoint_id=endpoint_id, predict_request_response_logging_config=predict_request_response_logging_config, dedicated_endpoint_enabled=dedicated_endpoint_enabled, + client_connection_config=client_connection_config, ) @classmethod @@ -925,6 +940,9 @@ def _create( gca_service_networking.PrivateServiceConnectConfig ] = None, dedicated_endpoint_enabled=False, + client_connection_config: Optional[ + gca_endpoint_compat.ClientConnectionConfig + ] = None, ) -> "Endpoint": """Creates a new endpoint by calling the API client. @@ -995,6 +1013,8 @@ def _create( Optional. If enabled, a dedicated dns will be created and your traffic will be fully isolated from other customers' traffic and latency will be reduced. + client_connection_config (aiplatform.endpoint.ClientConnectionConfig): + Optional. The inference timeout which is applied on cloud-based (PSC, or dedicated) endpoints for online prediction. Returns: endpoint (aiplatform.Endpoint): @@ -1014,6 +1034,7 @@ def _create( predict_request_response_logging_config=predict_request_response_logging_config, private_service_connect_config=private_service_connect_config, dedicated_endpoint_enabled=dedicated_endpoint_enabled, + client_connection_config=client_connection_config, ) operation_future = api_client.create_endpoint( @@ -3253,6 +3274,7 @@ def create( encryption_spec_key_name: Optional[str] = None, sync=True, private_service_connect_config: Optional[PrivateServiceConnectConfig] = None, + inference_timeout: Optional[int] = None, ) -> "PrivateEndpoint": """Creates a new PrivateEndpoint. @@ -3338,6 +3360,8 @@ def create( private_service_connect_config (aiplatform.PrivateEndpoint.PrivateServiceConnectConfig): [Private Service Connect](https://cloud.google.com/vpc/docs/private-service-connect) configuration for the endpoint. Cannot be set when network is specified. + inference_timeout (int): + Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true. Returns: endpoint (aiplatform.PrivateEndpoint): @@ -3374,6 +3398,12 @@ def create( private_service_connect_config._gapic_private_service_connect_config ) + client_connection_config = None + if private_service_connect_config and inference_timeout: + client_connection_config = gca_endpoint_compat.ClientConnectionConfig( + inference_timeout=duration_pb2.Duration(seconds=inference_timeout) + ) + return cls._create( api_client=api_client, display_name=display_name, @@ -3388,6 +3418,7 @@ def create( network=network, sync=sync, private_service_connect_config=config, + client_connection_config=client_connection_config, ) @classmethod diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index ee8b222c1c..82bdaa3f52 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -21,6 +21,7 @@ import json import requests from unittest import mock +from google.protobuf import duration_pb2 from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials @@ -269,6 +270,11 @@ ) ) +_TEST_INFERENCE_TIMEOUT = 100 +_TEST_CLIENT_CONNECTION_CONFIG = gca_endpoint.ClientConnectionConfig( + inference_timeout=duration_pb2.Duration(seconds=_TEST_INFERENCE_TIMEOUT) +) + """ ---------------------------------------------------------------------------- Endpoint Fixtures @@ -1258,6 +1264,34 @@ def test_create_dedicated_endpoint(self, create_dedicated_endpoint_mock, sync): endpoint_id=None, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dedicated_endpoint_with_timeout( + self, create_dedicated_endpoint_mock, sync + ): + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + dedicated_endpoint_enabled=True, + sync=sync, + inference_timeout=_TEST_INFERENCE_TIMEOUT, + ) + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + dedicated_endpoint_enabled=True, + client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG, + ) + create_dedicated_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + endpoint=expected_endpoint, + metadata=(), + timeout=None, + endpoint_id=None, + ) + @pytest.mark.usefixtures("get_empty_endpoint_mock") def test_accessing_properties_with_no_resource_raises( self, @@ -3441,6 +3475,39 @@ def test_create_psc(self, create_psc_private_endpoint_mock, sync): endpoint_id=None, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_create_psc_with_timeout(self, create_psc_private_endpoint_mock, sync): + test_endpoint = models.PrivateEndpoint.create( + display_name=_TEST_DISPLAY_NAME, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + private_service_connect_config=models.PrivateEndpoint.PrivateServiceConnectConfig( + project_allowlist=_TEST_PROJECT_ALLOWLIST + ), + sync=sync, + inference_timeout=_TEST_INFERENCE_TIMEOUT, + ) + + if not sync: + test_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=_TEST_PROJECT_ALLOWLIST, + ), + client_connection_config=_TEST_CLIENT_CONNECTION_CONFIG, + ) + + create_psc_private_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + endpoint=expected_endpoint, + metadata=(), + timeout=None, + endpoint_id=None, + ) + @pytest.mark.usefixtures("get_psa_private_endpoint_with_model_mock") def test_psa_predict(self, predict_private_endpoint_mock): test_endpoint = models.PrivateEndpoint(_TEST_ID)