Skip to content

Commit

Permalink
feat: Added async prediction and explanation support to the `Endpoint…
Browse files Browse the repository at this point in the history
…` class

* Added the `Endpoint.predict_async` method
* Added the `Endpoint.explain_async` method
* Made it possible to use async clients in classes derived from `VertexAiResourceNounWithFutureManager` that use `@optional_sync`.

PiperOrigin-RevId: 565472250
  • Loading branch information
Ark-kun authored and copybara-github committed Sep 14, 2023
1 parent 8b0add1 commit e9eb159
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 6 deletions.
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
services.model_garden_service_client = services.model_garden_service_client_v1beta1
services.pipeline_service_client = services.pipeline_service_client_v1beta1
services.prediction_service_client = services.prediction_service_client_v1beta1
services.prediction_service_async_client = (
services.prediction_service_async_client_v1beta1
)
services.schedule_service_client = services.schedule_service_client_v1beta1
services.specialist_pool_service_client = (
services.specialist_pool_service_client_v1beta1
Expand Down Expand Up @@ -144,6 +147,9 @@
services.model_service_client = services.model_service_client_v1
services.pipeline_service_client = services.pipeline_service_client_v1
services.prediction_service_client = services.prediction_service_client_v1
services.prediction_service_async_client = (
services.prediction_service_async_client_v1
)
services.schedule_service_client = services.schedule_service_client_v1
services.specialist_pool_service_client = services.specialist_pool_service_client_v1
services.tensorboard_service_client = services.tensorboard_service_client_v1
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/aiplatform/compat/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as prediction_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
async_client as prediction_service_async_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.schedule_service import (
client as schedule_service_client_v1beta1,
)
Expand Down Expand Up @@ -109,6 +112,9 @@
from google.cloud.aiplatform_v1.services.prediction_service import (
client as prediction_service_client_v1,
)
from google.cloud.aiplatform_v1.services.prediction_service import (
async_client as prediction_service_async_client_v1,
)
from google.cloud.aiplatform_v1.services.schedule_service import (
client as schedule_service_client_v1,
)
Expand Down Expand Up @@ -136,6 +142,7 @@
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
prediction_service_async_client_v1,
schedule_service_client_v1,
specialist_pool_service_client_v1,
tensorboard_service_client_v1,
Expand All @@ -155,6 +162,7 @@
persistent_resource_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
prediction_service_async_client_v1beta1,
schedule_service_client_v1beta1,
specialist_pool_service_client_v1beta1,
metadata_service_client_v1beta1,
Expand Down
164 changes: 158 additions & 6 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import pathlib
import re
Expand Down Expand Up @@ -226,7 +227,10 @@ def __init__(
# Lazy load the Endpoint gca_resource until needed
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)

self._prediction_client = self._instantiate_prediction_client(
(
self._prediction_client,
self._prediction_async_client,
) = self._instantiate_prediction_clients(
location=self.location,
credentials=credentials,
)
Expand Down Expand Up @@ -572,7 +576,10 @@ def _construct_sdk_resource_from_gapic(
credentials=credentials,
)

endpoint._prediction_client = cls._instantiate_prediction_client(
(
endpoint._prediction_client,
endpoint._prediction_async_client,
) = cls._instantiate_prediction_clients(
location=endpoint.location,
credentials=credentials,
)
Expand Down Expand Up @@ -1384,10 +1391,12 @@ def _undeploy(
self._sync_gca_resource()

@staticmethod
def _instantiate_prediction_clients(
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Tuple[
    utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
]:
    """Instantiates the sync and async prediction clients for this endpoint.

    Args:
        location (str):
            Optional. Location override for the prediction clients.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials used when calling the prediction
            service.

    Returns:
        prediction_client (prediction_service_client.PredictionServiceClient):
            Initialized prediction client with optional overrides.
        prediction_async_client (PredictionServiceAsyncClient):
            Initialized prediction clients with optional overrides.
    """
    # The PredictionServiceAsyncClient constructor calls
    # `asyncio.get_event_loop`, which raises RuntimeError when no event loop
    # exists — the default situation in non-main threads such as the thread
    # pool used when `sync=False` — so make sure a loop is installed first.
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        asyncio.set_event_loop(asyncio.new_event_loop())

    async_client = initializer.global_config.create_client(
        client_class=utils.PredictionAsyncClientWithOverride,
        credentials=credentials,
        location_override=location,
        prediction_client=True,
    )
    # Deliberately not reusing `async_client._client` here: that attribute is
    # a concrete `PredictionServiceClient`, whereas callers expect a
    # `PredictionClientWithOverride`.
    sync_client = initializer.global_config.create_client(
        client_class=utils.PredictionClientWithOverride,
        credentials=credentials,
        location_override=location,
        prediction_client=True,
    )
    return (sync_client, async_client)

def update(
self,
Expand Down Expand Up @@ -1581,6 +1610,65 @@ def predict(
model_resource_name=prediction_response.model,
)

async def predict_async(
    self,
    instances: List,
    *,
    parameters: Optional[Dict] = None,
    timeout: Optional[float] = None,
) -> Prediction:
    """Make an asynchronous prediction against this Endpoint.

    Example usage:
    ```
    response = await my_endpoint.predict_async(instances=[...])
    my_predictions = response.predictions
    ```

    Args:
        instances (List):
            Required. The instances that are the input to the
            prediction call. A DeployedModel may have an upper limit
            on the number of instances it supports per request, and
            when it is exceeded the prediction call errors in case
            of AutoML Models, or, in case of customer created
            Models, the behaviour is as documented by that Model.
            The schema of any single instance may be specified via
            Endpoint's DeployedModels'
            [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
            [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
            ``instance_schema_uri``.
        parameters (Dict):
            Optional. The parameters that govern the prediction. The schema of
            the parameters may be specified via Endpoint's
            DeployedModels' [Model's
            ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
            [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
            ``parameters_schema_uri``.
        timeout (float): Optional. The timeout for this request in seconds.

    Returns:
        prediction (aiplatform.Prediction):
            Prediction with returned predictions and Model ID.
    """
    # Block until the endpoint resource itself is ready before predicting.
    self.wait()

    response = await self._prediction_async_client.predict(
        endpoint=self._gca_resource.name,
        instances=instances,
        parameters=parameters,
        timeout=timeout,
    )

    # Convert the proto prediction values into plain Python structures.
    predictions = [
        json_format.MessageToDict(item) for item in response.predictions.pb
    ]

    return Prediction(
        predictions=predictions,
        deployed_model_id=response.deployed_model_id,
        model_version_id=response.model_version_id,
        model_resource_name=response.model,
    )

def raw_predict(
self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
Expand Down Expand Up @@ -1676,6 +1764,70 @@ def explain(
explanations=explain_response.explanations,
)

async def explain_async(
    self,
    instances: List[Dict],
    *,
    parameters: Optional[Dict] = None,
    deployed_model_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> Prediction:
    """Make a prediction with explanations against this Endpoint.

    Example usage:
    ```
    response = await my_endpoint.explain_async(instances=[...])
    my_explanations = response.explanations
    ```

    Args:
        instances (List):
            Required. The instances that are the input to the
            prediction call. A DeployedModel may have an upper limit
            on the number of instances it supports per request, and
            when it is exceeded the prediction call errors in case
            of AutoML Models, or, in case of customer created
            Models, the behaviour is as documented by that Model.
            The schema of any single instance may be specified via
            Endpoint's DeployedModels'
            [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
            [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
            ``instance_schema_uri``.
        parameters (Dict):
            The parameters that govern the prediction. The schema of
            the parameters may be specified via Endpoint's
            DeployedModels' [Model's
            ][google.cloud.aiplatform.v1beta1.DeployedModel.model]
            [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
            ``parameters_schema_uri``.
        deployed_model_id (str):
            Optional. If specified, this ExplainRequest will be served by the
            chosen DeployedModel, overriding this Endpoint's traffic split.
        timeout (float): Optional. The timeout for this request in seconds.

    Returns:
        prediction (aiplatform.Prediction):
            Prediction with returned predictions, explanations, and Model ID.
    """
    # Block until the endpoint resource itself is ready before predicting.
    self.wait()

    response = await self._prediction_async_client.explain(
        endpoint=self.resource_name,
        instances=instances,
        parameters=parameters,
        deployed_model_id=deployed_model_id,
        timeout=timeout,
    )

    # Convert the proto prediction values into plain Python structures.
    predictions = [
        json_format.MessageToDict(item) for item in response.predictions.pb
    ]

    return Prediction(
        predictions=predictions,
        deployed_model_id=response.deployed_model_id,
        explanations=response.explanations,
    )

@classmethod
def list(
cls,
Expand Down
16 changes: 16 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
model_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
prediction_service_async_client_v1beta1,
schedule_service_client_v1beta1,
tensorboard_service_client_v1beta1,
vizier_service_client_v1beta1,
Expand All @@ -68,6 +69,7 @@
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
prediction_service_async_client_v1,
schedule_service_client_v1,
tensorboard_service_client_v1,
vizier_service_client_v1,
Expand All @@ -89,6 +91,7 @@
index_endpoint_service_client_v1beta1.IndexEndpointServiceClient,
model_service_client_v1beta1.ModelServiceClient,
prediction_service_client_v1beta1.PredictionServiceClient,
prediction_service_async_client_v1beta1.PredictionServiceAsyncClient,
pipeline_service_client_v1beta1.PipelineServiceClient,
job_service_client_v1beta1.JobServiceClient,
match_service_client_v1beta1.MatchServiceClient,
Expand All @@ -104,6 +107,7 @@
metadata_service_client_v1.MetadataServiceClient,
model_service_client_v1.ModelServiceClient,
prediction_service_client_v1.PredictionServiceClient,
prediction_service_async_client_v1.PredictionServiceAsyncClient,
pipeline_service_client_v1.PipelineServiceClient,
job_service_client_v1.JobServiceClient,
schedule_service_client_v1.ScheduleServiceClient,
Expand Down Expand Up @@ -616,6 +620,18 @@ class PredictionClientWithOverride(ClientWithOverride):
)


class PredictionAsyncClientWithOverride(ClientWithOverride):
    """Async counterpart of PredictionClientWithOverride.

    Selects the v1 or v1beta1 PredictionServiceAsyncClient according to the
    requested API version.
    """

    # The async prediction client is kept for the lifetime of the resource
    # rather than created per call.
    _is_temporary = False
    # Fall back to the package-wide default API version when none is given.
    _default_version = compat.DEFAULT_VERSION
    # Maps each supported API version to its generated async GAPIC client.
    _version_map = (
        (compat.V1, prediction_service_async_client_v1.PredictionServiceAsyncClient),
        (
            compat.V1BETA1,
            prediction_service_async_client_v1beta1.PredictionServiceAsyncClient,
        ),
    )


class MatchClientWithOverride(ClientWithOverride):
_is_temporary = False
_default_version = compat.V1BETA1
Expand Down
9 changes: 9 additions & 0 deletions tests/system/aiplatform/test_model_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import json
import pytest

from google.cloud import aiplatform

Expand Down Expand Up @@ -64,3 +65,11 @@ def test_prediction(self):
)
assert raw_prediction_response.status_code == 200
assert len(json.loads(raw_prediction_response.text)) == 1

@pytest.mark.asyncio
async def test_endpoint_predict_async(self):
    """Checks that Endpoint.predict_async returns one prediction per instance."""
    response = await self.endpoint.predict_async(instances=[_PREDICTION_INSTANCE])
    assert len(response.predictions) == 1
Loading

0 comments on commit e9eb159

Please sign in to comment.